diff --git a/README.md b/README.md index b9c7ea2..2dedcd5 100644 --- a/README.md +++ b/README.md @@ -2,87 +2,6 @@ 运行于【海光 DCU】系列算力卡的【文本生成】引擎,基于 vLLM 引擎进行架构特别适配优化,支持 Qwen、DeepSeek、Llama 等最新开源模型。 -因具体模型之间的启动方式和具体镜像会有略微差别,请详细查看 `/enginex` 目录下各个支持模型的启动测试方式。 +源镜像:harbor.sourcefind.cn:5443/dcu/admin/base/vllm:0.9.2-ubuntu22.04-dtk25.04.2-1226-das1.7-py3.10-20251226 -## 可支持模型列表 -可在项目文件夹 `/enginex` 下查看具体可支持模型文件的运行方式。 - -支持模型列表: -- jinaai/jina-embeddings-v3 -- deepseek-ai/DeepSeek-R1 -- Qwen/QwQ-32B -- deepseek-ai/DeepSeek-V3 -- deepseek-ai/DeepSeek-V3.1 -- LLaMA_Fastchat_pytorch -- Qwen/Qwen3-30B-A3B -- Qwen-7B_fastllm -- ChatGLM-6B_fastllm -- ZhipuAI/ChatGLM-6B -- Shanghai_AI_Laboratory/internlm-chat-7b -- ZhipuAI/glm-4v-9b -- ZhipuAI/GLM-4-9B-0414 -- deepseek-ai/DeepSeek-Coder-V2-Base -- openai-community/gpt2 -- ZhipuAI/chatglm2-6b -- Qwen/Qwen-7B-Chat -- baichuan-inc/Baichuan2-13B-Chat -- ZhipuAI/chatglm3-6b -- deepseek-ai/DeepSeek-V2 -- Qwen/Qwen2.5-Omni-7B -- deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B -- deepseek-ai/DeepSeek-R1-Distill-Qwen-7B -- deepseek-ai/DeepSeek-R1-Distill-Llama-8B -- deepseek-ai/DeepSeek-R1-Distill-Qwen-14B -- deepseek-ai/DeepSeek-R1-Distill-Qwen-32B -- deepseek-ai/DeepSeek-R1-Distill-Llama-70B -- LLM-Research/Meta-Llama-3-8B-Instruct -- Qwen/Qwen1.5-14B-Chat -- Qwen/Qwen2-7B -- Qwen/Qwen3-Embedding-0.6B -- baichuan-inc/baichuan-7B -- openai-community/gpt2 -- gaodema/GME-Qwen2-VL -- OpenBMB/MiniCPM3-4B -- ZhipuAI/glm-10b-chinese -- 01ai/Yi-6B-Chat -- 01ai/Yi-34B-Chat -- ZhipuAI/glm-4-9b-chat -- deepseek-ai/DeepSeek-OCR -- Qwen/Qwen2.5-Coder-0.5B-Instruct -- Qwen/Qwen2.5-Coder-1.5B-Instruct -- Qwen/Qwen2.5-Coder-3B-Instruct -- Qwen/Qwen2.5-Coder-7B-Instruct -- Qwen/Qwen2.5-Coder-14B-Instruct -- Qwen/Qwen2.5-Coder-0.5B -- Qwen/Qwen2.5-Coder-1.5B -- Qwen/Qwen2.5-Coder-3B -- Qwen/Qwen2.5-Coder-7B -- Qwen/Qwen2.5-Coder-14B -- Qwen/Qwen2.5-Coder-32B -- deepseek-ai/DeepSeek-V3.2-Exp -- ZhipuAI/GLM-4.1V-9B-Thinking -- ZhipuAI/GLM-4.1V-9B-Base -- Shanghai_AI_Laboratory/internlm2_5-7b -- Shanghai_AI_Laboratory/internlm2-chat-20b -- Shanghai_AI_Laboratory/internlm2-7b -- Shanghai_AI_Laboratory/internlm2_5-20b -- TeleAI/telechat-7B -- TeleAI/TeleChat-12B-v2 -- OpenBMB/MiniCPM-2B-dpo-bf16 -- LLM-Research/Phi-4-multimodal-instruct -- LLM-Research/Mistral-7B-Instruct-v0.3 -- Shanghai_AI_Laboratory/internlm2_5-7b-chat -- shakechen/Llama-2-7b-hf -- Qwen/Qwen2-Audio-7B-Instruct -- AI-ModelScope/gemma-2-2b -- AI-ModelScope/falcon-7b-instruct -- Duxiaoman-DI/XuanYuan-13B-Chat -- ZhipuAI/GLM-4.6 -- LLM-Research/Codestral-22B-v0.1 -- facebook/llm-compiler-7b -- 01ai/Yi-1.5-6B-Chat -- FreedomIntelligence/HuatuoGPT-o1-8B -- ZhipuAI/GLM-Z1-32B-0414 -- Salesforce/Llama-xLAM-2-8b-fc-r -- Qwen/Qwen3-235B-A22B -- Qwen/Qwen3-Coder-480B-A35B-Instruct \ No newline at end of file +版本:0.9.2 diff --git a/enginex/Baichuan2-13B-Chat.md b/enginex/Baichuan2-13B-Chat.md deleted file mode 100644 index ccfa515..0000000 --- a/enginex/Baichuan2-13B-Chat.md +++ /dev/null @@ -1,10 +0,0 @@ -# 运行方式 - -```python -# 推荐使用docker方式运行,提供拉取的docker镜像: -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 -docker run -dit --shm-size 80g --network=host --name=baichuan2 --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root -v /opt/hyhal/:/opt/hyhal/:ro image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 /bin/bash -docker exec -it baichuan2 /bin/bash -# 安装docker中没有的依赖: -pip install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -``` diff --git a/enginex/ChatGLM-6B.md b/enginex/ChatGLM-6B.md deleted file mode 100644 index 53594ad..0000000 --- a/enginex/ChatGLM-6B.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-centos7.6-dtk24.04-py310 -docker run -dit --network=host --name=chatglm --privileged --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size=16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root --ulimit stack=-1:-1 --ulimit memlock=-1:-1 -v /opt/hyhal/:/opt/hyhal/:ro git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-centos7.6-dtk24.04-py310 /usr/sbin/init -docker exec -it chatglm /bin/bash -pip install transformers==4.28.0 -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -pip install accelerate sentencepiece mdtex2html gradio rouge_chinese nltk jieba datasets protobuf peft pydantic==1.10.9 -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -``` diff --git a/enginex/ChatGLM-6B_fastllm.md b/enginex/ChatGLM-6B_fastllm.md deleted file mode 100644 index d326836..0000000 --- a/enginex/ChatGLM-6B_fastllm.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 - -推荐使用docker方式运行,提供拉取的docker镜像: -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/custom:glm-ft-v1.0 -# 自定义容器名 -# 当前工程所在路径 -docker run -it --name= -v :/work -w /work --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --cap-add=SYS_PTRACE --shm-size=16G --group-add 39 git.modelhub.org.cn:9443/enginex-hygon/custom:glm-ft-v1.0 /bin/bash -``` diff --git a/enginex/ChatGLM-6B_pytorch.md b/enginex/ChatGLM-6B_pytorch.md deleted file mode 100644 index 5808aad..0000000 --- a/enginex/ChatGLM-6B_pytorch.md +++ /dev/null @@ -1,12 +0,0 @@ -# 运行方式 - -推荐使用docker方式运行,提供拉取的docker镜像: -```python -# 推荐使用docker方式运行,提供拉取的docker镜像: -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-centos7.6-dtk24.04-py310 -# 进入docker,安装docker中没有的依赖: -docker run -dit --network=host --name=chatglm --privileged --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size=16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root --ulimit stack=-1:-1 --ulimit memlock=-1:-1 -v /opt/hyhal/:/opt/hyhal/:ro image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-centos7.6-dtk24.04-py310 /usr/sbin/init -docker exec -it chatglm /bin/bash -pip install transformers==4.28.0 -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -pip install accelerate sentencepiece mdtex2html gradio rouge_chinese nltk jieba datasets protobuf peft pydantic==1.10.9 -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -``` diff --git a/enginex/ChatGLM3-6B.md b/enginex/ChatGLM3-6B.md deleted file mode 100644 index 785561c..0000000 --- a/enginex/ChatGLM3-6B.md +++ /dev/null @@ -1,14 +0,0 @@ -# 运行方式 - -推荐使用docker方式运行,提供拉取的docker镜像: -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 -``` -进入docker,安装docker中没有的依赖: -```python -docker run -dit --network=host --name=chatglm3 --privileged --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size=16G -v /opt/hyhal/:/opt/hyhal/:ro --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root --ulimit stack=-1:-1 --ulimit memlock=-1:-1 image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 -docker exec -it chatglm3 /bin/bash -pip install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -cd finetune_demo -pip install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -``` diff --git a/enginex/Codestral-22B-v0.1.md b/enginex/Codestral-22B-v0.1.md deleted file mode 100644 index e47042b..0000000 --- a/enginex/Codestral-22B-v0.1.md +++ /dev/null @@ -1,11 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-centos7.6-dtk24.04-py310 -docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=80G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash - -cd /your_code_path/codestral_pytorch -pip install -r requirements.txt -pip install -U huggingface_hub hf_transfer -export HF_ENDPOINT=https://hf-mirror.com -``` diff --git a/enginex/DeepSeek-Coder-V2-Base.md b/enginex/DeepSeek-Coder-V2-Base.md deleted file mode 100644 index 1f35f64..0000000 --- a/enginex/DeepSeek-Coder-V2-Base.md +++ /dev/null @@ -1,11 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-centos7.6-dtk24.04-py310 -docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=80G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash - -cd /your_code_path/deepseek-coder-v2_pytorch -pip install -r requirements.txt -pip install -U huggingface_hub hf_transfer -export HF_ENDPOINT=https://hf-mirror.com -``` diff --git a/enginex/DeepSeek-OCR.md b/enginex/DeepSeek-OCR.md deleted file mode 100644 index fb6de17..0000000 --- a/enginex/DeepSeek-OCR.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/vllm:0.8.5-ubuntu22.04-dtk25.04.1-rc5-das1.6-py3.10-20250724 -docker run -it --shm-size 200g --network=host --name {docker_name} --privileged --device=/dev/kfd --device=/dev/dri --device=/dev/mkfd --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro {imageID} bash - -cd /your_code_path/deepseek-ocr_pytorch - -``` diff --git a/enginex/DeepSeek-R1-Distill.md b/enginex/DeepSeek-R1-Distill.md deleted file mode 100644 index 51f4b65..0000000 --- a/enginex/DeepSeek-R1-Distill.md +++ /dev/null @@ -1,10 +0,0 @@ -# 运行方式 -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.3.0-ubuntu22.04-dtk24.04.3-py3.10 - -docker run --shm-size 500g --network=host --name=dpskv3 --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v 项目地址(绝对路径):/home/ -v /opt/hyhal:/opt/hyhal:ro -it bash - -pip install https://download.sourcefind.cn:65024/directlink/4/lmslim/DAS1.3/lmslim-0.1.2+das.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl - -pip install https://download.sourcefind.cn:65024/directlink/4/vllm/DAS1.3/vllm-0.6.2+das.opt1.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl -``` \ No newline at end of file diff --git a/enginex/DeepSeek-R1.md b/enginex/DeepSeek-R1.md deleted file mode 100644 index f09436b..0000000 --- a/enginex/DeepSeek-R1.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 -```python -docker git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.3.0-ubuntu22.04-dtk24.04.3-py3.10 - -docker run --shm-size 500g --network=host --name=dpskr1 --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v 项目地址(绝对路径):/home/ -v /opt/hyhal:/opt/hyhal:ro -it bash - -cd inference -pip install -r requirements.txt -``` \ No newline at end of file diff --git a/enginex/DeepSeek-R1_ollama.md b/enginex/DeepSeek-R1_ollama.md deleted file mode 100644 index b30a041..0000000 --- a/enginex/DeepSeek-R1_ollama.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 -```python -docker git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.3.0-py3.10-dtk24.04.3-ubuntu20.04 - -docker run --shm-size 500g --network=host --name=dpskr1 --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v 项目地址(绝对路径):/home/ -v /opt/hyhal:/opt/hyhal:ro -it bash - -cd inference -pip install -r requirements.txt -``` \ No newline at end of file diff --git a/enginex/DeepSeek-V2.md b/enginex/DeepSeek-V2.md deleted file mode 100644 index ddb1131..0000000 --- a/enginex/DeepSeek-V2.md +++ /dev/null @@ -1,8 +0,0 @@ -# 运行方式 -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-centos7.6-dtk24.04-py310 -docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=80G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash - -cd /your_code_path/deepseek-v2_pytorch -pip install -r requirements.txt -``` \ No newline at end of file diff --git a/enginex/DeepSeek-V3.1.md b/enginex/DeepSeek-V3.1.md deleted file mode 100644 index e3923a6..0000000 --- a/enginex/DeepSeek-V3.1.md +++ /dev/null @@ -1,8 +0,0 @@ -# 运行方式 -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/vllm:0.9.2-ubuntu22.04-dtk25.04.1-rc5-rocblas101839-0811-das1.6-py3.10-20250812-beta - -docker run -it --name {docker_name} --device=/dev/kfd --privileged --network=host --device=/dev/dri --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /your_code_path:/your_code_path -v /opt/hyhal:/opt/hyhal:ro -v /module/DeepSeek-V3.1:/your_model_path/DeepSeek-V3.1 --group-add video --shm-size 64G {imageID} bash - -cd /your_code_path/deepseek-v3.1_vllm -``` \ No newline at end of file diff --git a/enginex/DeepSeek-V3.2-Exp.md b/enginex/DeepSeek-V3.2-Exp.md deleted file mode 100644 index dd086d0..0000000 --- a/enginex/DeepSeek-V3.2-Exp.md +++ /dev/null @@ -1,8 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/vllm:0.9.2-ubuntu22.04-dtk25.04.1-rc5-rocblas104381-0915-das1.6-py3.10-20250916-rc2-ds3.2 -docker run -it --shm-size 200g --network=host --name {docker_name} --privileged --device=/dev/kfd --device=/dev/dri --device=/dev/mkfd --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro {imageID} bash - -cd /your_code_path/deepseek-v3.2-exp_vllm -``` diff --git a/enginex/DeepSeek-V3.md b/enginex/DeepSeek-V3.md deleted file mode 100644 index 2097a0f..0000000 --- a/enginex/DeepSeek-V3.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.3.0-ubuntu22.04-dtk24.04.3-py3.10 - -docker run --shm-size 500g --network=host --name=dpskv3 --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v 项目地址(绝对路径):/home/ -v /opt/hyhal:/opt/hyhal:ro -it bash - -cd inference -pip install -r requirements.txt -``` \ No newline at end of file diff --git a/enginex/GLM-4-9B-0414.md b/enginex/GLM-4-9B-0414.md deleted file mode 100644 index 9c46293..0000000 --- a/enginex/GLM-4-9B-0414.md +++ /dev/null @@ -1,12 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.2-py3.10 -docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=64G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name glm-4v bash - -cd /path/your_code_data/ - -pip install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -#开发者社区下载bitsandbytes -pip install bitsandbytes-0.42.0+das1.1.gitce85679.abi1.dtk2404.torch2.1.0-py3-none-any.whl -``` diff --git a/enginex/GLM-4.1V.md b/enginex/GLM-4.1V.md deleted file mode 100644 index f60d112..0000000 --- a/enginex/GLM-4.1V.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/vllm:0.8.5-ubuntu22.04-dtk25.04.1-rc5-das1.6-py3.10-20250711 -docker run -it --shm-size 200g --network=host --name {docker_name} --privileged --device=/dev/kfd --device=/dev/dri --device=/dev/mkfd --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro {imageID} bash - -cd /your_code_path/glm-4.1v_pytorch -pip install transformers==4.53.2 -``` diff --git a/enginex/GLM-4.6.md b/enginex/GLM-4.6.md deleted file mode 100644 index 2c7b539..0000000 --- a/enginex/GLM-4.6.md +++ /dev/null @@ -1,8 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/vllm:0.9.2-ubuntu22.04-dtk25.04.1-rc5-rocblas104381-0915-das1.6-py3.10-20250916-rc2 -docker run -it --shm-size 200g --network=host --name {docker_name} --privileged --device=/dev/kfd --device=/dev/dri --device=/dev/mkfd --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro {imageID} bash - -cd /your_code_path/glm-4.6_vllm -``` diff --git a/enginex/GLM-Z1-32B-0414.md b/enginex/GLM-Z1-32B-0414.md deleted file mode 100644 index 7375c55..0000000 --- a/enginex/GLM-Z1-32B-0414.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.4.1-ubuntu22.04-dtk25.04-py3.10 -docker run -it --shm-size 200g --network=host --name {docker_name} --privileged --device=/dev/kfd --device=/dev/dri --device=/dev/mkfd --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro {imageID} bash - -cd /your_code_path/glm-z1_pytorch -pip install transformers>=4.51.3 -``` diff --git a/enginex/GME-Qwen2-VL.md b/enginex/GME-Qwen2-VL.md deleted file mode 100644 index b3456ab..0000000 --- a/enginex/GME-Qwen2-VL.md +++ /dev/null @@ -1,8 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.4.1-ubuntu22.04-dtk25.04-py3.10 -docker run --shm-size 100g --network=host --name=gme --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v 项目地址(绝对路径):/home/ -v /opt/hyhal:/opt/hyhal:ro -it bash - -pip install -r requirements.txt -``` diff --git a/enginex/HuatuoGPT-o1-8B.md b/enginex/HuatuoGPT-o1-8B.md deleted file mode 100644 index 80f7dbf..0000000 --- a/enginex/HuatuoGPT-o1-8B.md +++ /dev/null @@ -1,15 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.3.0-ubuntu22.04-dtk24.04.3-py3.10 - -docker run --shm-size 50g --network=host --name=huatuo --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v 项目地址(绝对路径):/home/ -v /opt/hyhal:/opt/hyhal:ro -it bash - -pip install -r requirements.txt - -pip uninstall vllm - -pip install https://download.sourcefind.cn:65024/directlink/4/lmslim/DAS1.3/lmslim-0.1.2+das.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl - -pip install https://download.sourcefind.cn:65024/directlink/4/vllm/DAS1.3/vllm-0.6.2+das.opt1.dtk24043-cp310-cp310-manylinux_2_28_x86_64.whl -``` diff --git a/enginex/Internlm.md b/enginex/Internlm.md deleted file mode 100644 index 804c0a5..0000000 --- a/enginex/Internlm.md +++ /dev/null @@ -1,10 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.3.0-ubuntu22.04-dtk24.04.3-py3.10 -# 用上面拉取docker镜像的ID替换 -# 主机端路径 -# 容器映射路径 -# 若要在主机端和容器端映射端口需要删除--network host参数 -docker run -it --name internlm_vllm --privileged --shm-size=64G --device=/dev/kfd --device=/dev/dri/ --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ulimit memlock=-1:-1 --ipc=host --network host --group-add video -v /opt/hyhal:/opt/hyhal -v : /bin/bash -``` diff --git a/enginex/LLaMA_Fastchat_pytorch.md b/enginex/LLaMA_Fastchat_pytorch.md deleted file mode 100644 index 2d58598..0000000 --- a/enginex/LLaMA_Fastchat_pytorch.md +++ /dev/null @@ -1,16 +0,0 @@ -# 运行方式 -```python -拉取镜像: -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 -创建并启动容器: -docker run --shm-size 64g --network=host --name=llama_fastchat --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /opt/hyhal:/opt/hyhal:ro -v : -it bash - -cp -r mpirun/* ./ -cd FastChat-main -pip3 install -e . -cd ../transformers-main -pip3 install -e . -pip3 uninstall wandb -pip3 install mpi4py -cd .. -``` \ No newline at end of file diff --git a/enginex/Llama-2-7b-hf.md b/enginex/Llama-2-7b-hf.md deleted file mode 100644 index f089253..0000000 --- a/enginex/Llama-2-7b-hf.md +++ /dev/null @@ -1,7 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.2-py3.10 - -docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal:/opt/hyhal:ro --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash -``` diff --git a/enginex/Llama-xLAM-2-8b-fc-r.md b/enginex/Llama-xLAM-2-8b-fc-r.md deleted file mode 100644 index 77c76cf..0000000 --- a/enginex/Llama-xLAM-2-8b-fc-r.md +++ /dev/null @@ -1,11 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.4.1-ubuntu22.04-dtk25.04-py3.10 - -docker run --shm-size 100g --network=host --name=wan --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v 项目地址(绝对路径):/home/ -v /opt/hyhal:/opt/hyhal:ro -it bash - -pip install -r requirements.txt - -pip install -e . -``` diff --git a/enginex/Meta-Llama-3-8B-Instruct.md b/enginex/Meta-Llama-3-8B-Instruct.md deleted file mode 100644 index 203999d..0000000 --- a/enginex/Meta-Llama-3-8B-Instruct.md +++ /dev/null @@ -1,10 +0,0 @@ -# 运行方式 - -```python -# 推荐使用docker方式运行,提供拉取的docker镜像: -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-centos7.6-dtk24.04-py310 -docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=80G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash - -cd /your_code_path/llama3_pytorch -pip install -e . -``` diff --git a/enginex/MiniCPM-2B-dpo-bf16.md b/enginex/MiniCPM-2B-dpo-bf16.md deleted file mode 100644 index 2bc4cb8..0000000 --- a/enginex/MiniCPM-2B-dpo-bf16.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 -# 为以上拉取的docker的镜像ID替换 -docker run -it --shm-size=32G -v $PWD/MiniCPM:/home/MiniCPM -v /opt/hyhal:/opt/hyhal --network=host --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name minicpm bash -cd /home/MiniCPM -pip install -r finetune/requirements.txt # finetune/requirements.txt -``` diff --git a/enginex/MiniCPM3-4B.md b/enginex/MiniCPM3-4B.md deleted file mode 100644 index 2bc4cb8..0000000 --- a/enginex/MiniCPM3-4B.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 -# 为以上拉取的docker的镜像ID替换 -docker run -it --shm-size=32G -v $PWD/MiniCPM:/home/MiniCPM -v /opt/hyhal:/opt/hyhal --network=host --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name minicpm bash -cd /home/MiniCPM -pip install -r finetune/requirements.txt # finetune/requirements.txt -``` diff --git a/enginex/Mistral-7B-Instruct-v0.3.md b/enginex/Mistral-7B-Instruct-v0.3.md deleted file mode 100644 index 2bc5cf0..0000000 --- a/enginex/Mistral-7B-Instruct-v0.3.md +++ /dev/null @@ -1,8 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/custom:vllm0.8.5-ubuntu22.04-dtk25.04-rc7-das1.5-py3.10-20250521-fixpy-rocblas0521-beta2 -docker run -it --shm-size 200g --network=host --name {docker_name} --privileged --device=/dev/kfd --device=/dev/dri --device=/dev/mkfd --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro {imageID} bash - -cd /your_code_path/mistral_pytorch -``` diff --git a/enginex/Phi-4-multimodal-instruct.md b/enginex/Phi-4-multimodal-instruct.md deleted file mode 100644 index 4d6279a..0000000 --- a/enginex/Phi-4-multimodal-instruct.md +++ /dev/null @@ -1,12 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.3.0-py3.10-dtk24.04.3-ubuntu20.04 -docker run -it --shm-size=1024G -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal:/opt/hyhal --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name phi-4 bash # 为以上拉取的docker的镜像ID替换 - -git clone http://developer.sourcefind.cn/codes/modelzoo/phi-4-multimodal-instruct_pytorch.git - -cd /path/your_code_data/ - -pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple -``` diff --git a/enginex/QwQ-32B.md b/enginex/QwQ-32B.md deleted file mode 100644 index 3a58392..0000000 --- a/enginex/QwQ-32B.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.4.1-ubuntu22.04-dtk25.04-py3.10 -# 为以上拉取的docker的镜像ID替换,本镜像为:dee41741fb40 -docker run -it --shm-size=64G --network host -v $PWD/QwQ-32B:/home/QwQ-32B -v /opt/hyhal:/opt/hyhal:ro --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name qwq bash -cd /home/QwQ-32B -pip install -r requirements.txt -``` diff --git a/enginex/Qwen-7B-Chat.md b/enginex/Qwen-7B-Chat.md deleted file mode 100644 index 606325b..0000000 --- a/enginex/Qwen-7B-Chat.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 - -```python -# 推荐使用docker方式运行,提供拉取的docker镜像: -docker pull git.modelhub.org.cn:9443/enginex-hygon/git.modelhub.org.cn:9443/enginex-hygon/pytorch:1.13.1-centos7.6-dtk-23.04-py38-latest -# 自定义容器名 -# 当前工程所在路径 -docker run -it --name= -v :/work --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --cap-add=SYS_PTRACE --shm-size=16G --group-add 39 image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk-23.04-py38-latest /bin/bash -``` diff --git a/enginex/Qwen-7B_fastllm.md b/enginex/Qwen-7B_fastllm.md deleted file mode 100644 index 2b1b711..0000000 --- a/enginex/Qwen-7B_fastllm.md +++ /dev/null @@ -1,8 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:1.13.1-centos7.6-dtk-23.04-py38-latest -# 自定义容器名 -# 当前工程所在路径 -docker run -it --name= -v :/work --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --cap-add=SYS_PTRACE --shm-size=16G --group-add 39 git.modelhub.org.cn:9443/enginex-hygon/pytorch:1.13.1-centos7.6-dtk-23.04-py38-latest /bin/bash -``` diff --git a/enginex/Qwen1.5-14B-Chat.md b/enginex/Qwen1.5-14B-Chat.md deleted file mode 100644 index 68c6872..0000000 --- a/enginex/Qwen1.5-14B-Chat.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 - -```python -# 推荐使用docker方式运行,提供拉取的docker镜像: -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 -docker run -it --shm-size=1024G -v $PWD/qwen1.5-pytorch:/home/Qwen1.5-pytorch -v /opt/hyhal:/opt/hyhal --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name Qwen1.5-pytorch bash # 为以上拉取的docker的镜像ID替换,本镜像为:ffa1f63239fc -cd /home/Qwen1.5-pytorch -pip install -r requirement.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -``` diff --git a/enginex/Qwen2-7B.md b/enginex/Qwen2-7B.md deleted file mode 100644 index 477382e..0000000 --- a/enginex/Qwen2-7B.md +++ /dev/null @@ -1,11 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 -docker run -it --shm-size=1024G -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal:/opt/hyhal --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name qwen2_72B_pytorch bash # 为以上拉取的docker的镜像ID替换,本镜像为:a4dd5be0ca23 -pip install https://cancon.hpccube.com:65024/directlink/4/vllm/DAS1.1.1/vllm-0.5.0+das.opt1.3e2c63a.dtk2404.torch2.1.0-cp310-cp310-linux_x86_64.whl -cd /path/your_code_data/ -cd LLaMA-Factory -pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -pip install e . -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -``` diff --git a/enginex/Qwen2-Audio-7B-Instruct.md b/enginex/Qwen2-Audio-7B-Instruct.md deleted file mode 100644 index d264a94..0000000 --- a/enginex/Qwen2-Audio-7B-Instruct.md +++ /dev/null @@ -1,15 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.3.0-ubuntu22.04-dtk24.04.3-py3.10 - -docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=64G --privileged=true --network=host --device=/dev/kfd --device=/dev/dri/ --group-add video --name qwen2-audio bash - -cd /path/your_code_data/Qwen2-Audio/demo - -pip install -r requirements_web_demo.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com - -pip install git+https://github.com/modelscope/swift.git#egg=ms-swift[llm] - -pip install git+https://github.com/huggingface/transformers.git -``` diff --git a/enginex/Qwen2.5-Omni-7B.md b/enginex/Qwen2.5-Omni-7B.md deleted file mode 100644 index a6504d1..0000000 --- a/enginex/Qwen2.5-Omni-7B.md +++ /dev/null @@ -1,13 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.4.1-ubuntu22.04-dtk25.04-py3.10-fixpy -# 为以上拉取的docker的镜像ID替换,本镜像为:e77c15729879 -docker run -it --shm-size=64G -v $PWD/Qwen2.5-Omni:/home/Qwen2.5-Omni -v /opt/hyhal:/opt/hyhal:ro --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name qomni bash -cd /home/Qwen2.5-Omni -pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple - -unzip f742a644ca32e65758c3adb36225aef1731bd2a8.zip -cd transformers-f742a644ca32e65758c3adb36225aef1731bd2a8 -pip install -e . # 作者限定只能使用transformers==4.50.0.dev0 -``` diff --git a/enginex/Qwen3-30B-A3B.md b/enginex/Qwen3-30B-A3B.md deleted file mode 100644 index 96cc4de..0000000 --- a/enginex/Qwen3-30B-A3B.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/vllm:0.8.5-ubuntu22.04-dtk25.04.1-rc5-das1.6-py3.10-20250724 - -docker run -it --name {docker_name} --device=/dev/kfd --privileged --network=host --device=/dev/dri --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /public/LLM-Models:/home/LLM-Models:ro -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal:/opt/hyhal:ro --group-add video --shm-size 64G {imageID} bash - -cd /your_code_path/qwen3-30b-a3b_vllm -``` diff --git a/enginex/Qwen3-Embedding-0.6B.md b/enginex/Qwen3-Embedding-0.6B.md deleted file mode 100644 index 89e29cc..0000000 --- a/enginex/Qwen3-Embedding-0.6B.md +++ /dev/null @@ -1,10 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/custom:vllm0.8.5-ubuntu22.04-dtk25.04-rc7-das1.5-py3.10-20250521-fixpy-rocblas0521-beta2 -docker run -it --shm-size 200g --network=host --name {docker_name} --privileged --device=/dev/kfd --device=/dev/dri --device=/dev/mkfd --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro {imageID} bash - -cd /your_code_path/qwen3-embedding_pytorch -pip install transformers>=4.51.0 -pip install sentence-transformers>=2.7.0 -``` diff --git a/enginex/Qwen3.md b/enginex/Qwen3.md deleted file mode 100644 index 439dd30..0000000 --- a/enginex/Qwen3.md +++ /dev/null @@ -1,10 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/custom:vllm0.8.4-ubuntu22.04-dtk25.04-rc7-das1.5-py3.10-20250429-dev-qwen3-only -# docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.4.1-ubuntu22.04-dtk25.04-py3.10-fixpy -# 为以上拉取的docker的镜像ID替换,本镜像为:6e12a1c4ae4d -docker run -it --shm-size=64G -v $PWD/Qwen3:/home/Qwen3 -v /opt/hyhal:/opt/hyhal:ro --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name qwen3 bash -cd /home/Qwen3 -pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple -``` diff --git a/enginex/Qwen:Qwen3-8B.md b/enginex/Qwen:Qwen3-8B.md deleted file mode 100644 index c73f685..0000000 --- a/enginex/Qwen:Qwen3-8B.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/custom:vllm0.8.4-ubuntu22.04-dtk25.04-rc7-das1.5-py3.10-20250429-dev-qwen3-only -# 为以上拉取的docker的镜像ID替换 -docker run -it --shm-size=64G -v $PWD/Qwen3:/home/Qwen3 -v /opt/hyhal:/opt/hyhal:ro --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name qwen3 bash -cd /home/Qwen3 -pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple -``` diff --git a/enginex/TeleChat-12B-v2.md b/enginex/TeleChat-12B-v2.md deleted file mode 100644 index 4a2259e..0000000 --- a/enginex/TeleChat-12B-v2.md +++ /dev/null @@ -1,15 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 - -创建并启动容器: -docker run --shm-size 80g --network=host --name=telechat --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /opt/hyhal:/opt/hyhal:ro -v : -it bash - -安装依赖: -cd TeleChat -pip install -r requirements.txt -i https://pypi.mirrors.ustc.edu.cn/simple/ -pip install 'ms-swift[llm]' -U -i https://pypi.mirrors.ustc.edu.cn/simple/ -pip install optimum -i https://pypi.mirrors.ustc.edu.cn/simple/ -pip install auto-gptq -i https://pypi.mirrors.ustc.edu.cn/simple/ -``` diff --git a/enginex/XuanYuan-13B-Chat.md b/enginex/XuanYuan-13B-Chat.md deleted file mode 100644 index 7f2eefb..0000000 --- a/enginex/XuanYuan-13B-Chat.md +++ /dev/null @@ -1,17 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 - -docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=64G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name xuanyuan bash - -cd /path/your_code_data/ - -pip install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com - -git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git - -cd LLaMA-Factory - -pip install -e ".[torch,metrics]" -``` diff --git a/enginex/Yi-1.5-6B-Chat.md b/enginex/Yi-1.5-6B-Chat.md deleted file mode 100644 index ef5ef0f..0000000 --- a/enginex/Yi-1.5-6B-Chat.md +++ /dev/null @@ -1,10 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-py3.10-dtk24.04.3-ubuntu20.04 -docker run -it --shm-size=1024G -v : -v /opt/hyhal:/opt/hyhal --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name Yi-1.5 bash # 为以上拉取的docker的镜像ID替换 -cd /home/Yi-1.5-pytorch -pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com - -pip uninstall vllm -``` diff --git a/enginex/Yi-34B-Chat.md b/enginex/Yi-34B-Chat.md deleted file mode 100644 index 66b6d12..0000000 --- a/enginex/Yi-34B-Chat.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 -# 用上面拉取docker镜像的ID替换 -# 主机端路径 -# 容器映射路径 -docker run -it --name yi --shm-size=64G --device=/dev/kfd --device=/dev/dri/ --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /opt/hyhal:/opt/hyhal:ro --ulimit memlock=-1:-1 --ipc=host --network=host --group-add video -v : /bin/bash -``` diff --git a/enginex/Yi-6B-Chat.md b/enginex/Yi-6B-Chat.md deleted file mode 100644 index 66b6d12..0000000 --- a/enginex/Yi-6B-Chat.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 -# 用上面拉取docker镜像的ID替换 -# 主机端路径 -# 容器映射路径 -docker run -it --name yi --shm-size=64G --device=/dev/kfd --device=/dev/dri/ --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /opt/hyhal:/opt/hyhal:ro --ulimit memlock=-1:-1 --ipc=host --network=host --group-add video -v : /bin/bash -``` diff --git a/enginex/baichuan-7B.md b/enginex/baichuan-7B.md deleted file mode 100644 index 340b07b..0000000 --- a/enginex/baichuan-7B.md +++ /dev/null @@ -1,7 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 -docker run -dit --network=host --name=baichuan -v /opt/hyhal:/opt/hyhal:ro --privileged --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size=16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root --ulimit stack=-1:-1 --ulimit memlock=-1:-1 image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk-23.04-py38-latest /bin/bash -docker exec -it baichuan /bin/bash -``` diff --git a/enginex/chatglm2-6b.md b/enginex/chatglm2-6b.md deleted file mode 100644 index ba79897..0000000 --- a/enginex/chatglm2-6b.md +++ /dev/null @@ -1,8 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 -docker exec -it chatglm /bin/bash -pip install transformers==4.28.0 -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -pip install accelerate sentencepiece mdtex2html gradio rouge_chinese nltk jieba datasets==2.20.0 protobuf peft==0.5.0 pydantic==1.10.9 -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -``` diff --git a/enginex/falcon-7b-instruct.md b/enginex/falcon-7b-instruct.md deleted file mode 100644 index dc4240a..0000000 --- a/enginex/falcon-7b-instruct.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.2-py3.10 -docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=80G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash - -cd /your_code_path/falcon_pytorch -pip install -r requirements.txt -``` diff --git a/enginex/gemma-2-2b.md b/enginex/gemma-2-2b.md deleted file mode 100644 index 7d45856..0000000 --- a/enginex/gemma-2-2b.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.2-py3.10 -docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=80G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash - -cd /your_code_path/gemma2_pytorch -pip install -r requirements.txt -``` diff --git a/enginex/glm-10b-chinese.md b/enginex/glm-10b-chinese.md deleted file mode 100644 index 2bc4cb8..0000000 --- a/enginex/glm-10b-chinese.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 -# 为以上拉取的docker的镜像ID替换 -docker run -it --shm-size=32G -v $PWD/MiniCPM:/home/MiniCPM -v /opt/hyhal:/opt/hyhal --network=host --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name minicpm bash -cd /home/MiniCPM -pip install -r finetune/requirements.txt # finetune/requirements.txt -``` diff --git a/enginex/glm-4-9b-chat.md b/enginex/glm-4-9b-chat.md deleted file mode 100644 index bc3bd3e..0000000 --- a/enginex/glm-4-9b-chat.md +++ /dev/null @@ -1,10 +0,0 @@ -# 运行方式 - -```python -dcoker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.4.1-ubuntu22.04-dtk25.04-py3.10 -docker run -it --shm-size 200g --network=host --name {docker_name} --privileged --device=/dev/kfd --device=/dev/dri --device=/dev/mkfd --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro {imageID} bash - -cd /your_code_path/glm-4_pytorch -pip install -r inference/requirements.txt -pip install -r finetune/requirements.txt -``` diff --git a/enginex/glm-4v-9b.md b/enginex/glm-4v-9b.md deleted file mode 100644 index 61bfd10..0000000 --- a/enginex/glm-4v-9b.md +++ /dev/null @@ -1,13 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 - -docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=64G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name glm-4v bash - -cd /path/your_code_data/ - -pip install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -#开发者社区下载bitsandbytes -pip install bitsandbytes-0.42.0+das1.1.gitce85679.abi1.dtk2404.torch2.1.0-py3-none-any.whl -``` diff --git a/enginex/gpt2.md b/enginex/gpt2.md deleted file mode 100644 index 925b909..0000000 --- a/enginex/gpt2.md +++ /dev/null @@ -1,8 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:1.10.0-centos7.6-dtk-23.04-py37-latest -docker run -dit --network=host --name=gpt2_pytorch --privileged --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size=16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root --ulimit stack=-1:-1 --ulimit memlock=-1:-1 image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.10.0-centos7.6-dtk-23.04-py37-latest -docker exec -it gpt2_pytorch /bin/bash -pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com -``` diff --git a/enginex/internlm-chat-7b.md b/enginex/internlm-chat-7b.md deleted file mode 100644 index 166b23a..0000000 --- a/enginex/internlm-chat-7b.md +++ /dev/null @@ -1,9 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10(推荐) -# 用上面拉取docker镜像的ID替换 -# 主机端路径 -# 容器映射路径 -docker run -it --name baichuan --shm-size=1024G --device=/dev/kfd --device=/dev/dri/ --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /opt/hyhal:/opt/hyhal:ro --ulimit memlock=-1:-1 --ipc=host --network host --group-add video -v : /bin/bash -``` diff --git a/enginex/jina-embeddings-v3.md b/enginex/jina-embeddings-v3.md deleted file mode 100644 index 4f53058..0000000 --- a/enginex/jina-embeddings-v3.md +++ /dev/null @@ -1,8 +0,0 @@ -# 运行方式 -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/custom:vllm0.8.5-ubuntu22.04-dtk25.04-rc7-das1.5-py3.10-20250612-fixpy-rocblas0611-rc2 - -docker run -it --shm-size 200g --network=host --name {docker_name} --privileged --device=/dev/kfd --device=/dev/dri --device=/dev/mkfd --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro {imageID} bash - -cd /your_code_path/jina-embeddings-v3_vllm -``` \ No newline at end of file diff --git a/enginex/llm-compiler-7b.md b/enginex/llm-compiler-7b.md deleted file mode 100644 index 8b73a54..0000000 --- a/enginex/llm-compiler-7b.md +++ /dev/null @@ -1,11 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-centos7.6-dtk24.04-py310 -docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=80G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash - -cd /your_code_path/llm-compiler_pytorch -pip install -r requirements.txt -pip install -U huggingface_hub hf_transfer -export HF_ENDPOINT=https://hf-mirror.com -``` diff --git a/enginex/qwen2.5-coder.md b/enginex/qwen2.5-coder.md deleted file mode 100644 index e5e4c53..0000000 --- a/enginex/qwen2.5-coder.md +++ /dev/null @@ -1,13 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.4.1-ubuntu22.04-dtk25.04.1-py3.10 -docker run -it --name {name} --shm-size=1024G --device=/dev/kfd --device=/dev/dri/ --privileged --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ulimit memlock=-1:-1 --ipc=host --network host --group-add video -v /opt/hyhal:/opt/hyhal:ro -v {}:{} {docker_image} /bin/bash -# 修改1 {name} 需要改为自定义名称 -# 修改2 {docker_image} 需要需要创建容器的对应镜像名称 -# 修改3 -v 挂载路径到容器指定路径 -pip install -r requirements.txt -cd LLaMA-Factory -pip install -e ".[torch,metrics]" -pip install deepspeed-0.14.2+das.opt1.dtk25041-cp310-cp310-manylinux_2_28_x86_64.whl -``` diff --git a/enginex/telechat-7B.md b/enginex/telechat-7B.md deleted file mode 100644 index 4a2259e..0000000 --- a/enginex/telechat-7B.md +++ /dev/null @@ -1,15 +0,0 @@ -# 运行方式 - -```python -docker pull git.modelhub.org.cn:9443/enginex-hygon/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 - -创建并启动容器: -docker run --shm-size 80g --network=host --name=telechat --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /opt/hyhal:/opt/hyhal:ro -v : -it bash - -安装依赖: -cd TeleChat -pip install -r requirements.txt -i https://pypi.mirrors.ustc.edu.cn/simple/ -pip install 'ms-swift[llm]' -U -i https://pypi.mirrors.ustc.edu.cn/simple/ -pip install optimum -i https://pypi.mirrors.ustc.edu.cn/simple/ -pip install auto-gptq -i https://pypi.mirrors.ustc.edu.cn/simple/ -``` diff --git a/vllm/_C.abi3.so b/vllm/_C.abi3.so new file mode 100755 index 0000000..40f023e Binary files /dev/null and b/vllm/_C.abi3.so differ diff --git a/vllm/__init__.py b/vllm/__init__.py new file mode 100644 index 0000000..7b90fd3 --- /dev/null +++ b/vllm/__init__.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""vLLM: a high-throughput and memory-efficient inference engine for LLMs""" + +# The version.py should be independent library, and we always import the +# version library first. Such assumption is critical for some customization. +from .version import __version__, __version_tuple__ # isort:skip + +import typing + +# The environment variables override should be imported before any other +# modules to ensure that the environment variables are set before any +# other modules are imported. +import vllm.env_override # noqa: F401 + +MODULE_ATTRS = { + "AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs", + "EngineArgs": ".engine.arg_utils:EngineArgs", + "AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine", + "LLMEngine": ".engine.llm_engine:LLMEngine", + "LLM": ".entrypoints.llm:LLM", + "initialize_ray_cluster": ".executor.ray_utils:initialize_ray_cluster", + "PromptType": ".inputs:PromptType", + "TextPrompt": ".inputs:TextPrompt", + "TokensPrompt": ".inputs:TokensPrompt", + "ModelRegistry": ".model_executor.models:ModelRegistry", + "SamplingParams": ".sampling_params:SamplingParams", + "PoolingParams": ".pooling_params:PoolingParams", + "ClassificationOutput": ".outputs:ClassificationOutput", + "ClassificationRequestOutput": ".outputs:ClassificationRequestOutput", + "CompletionOutput": ".outputs:CompletionOutput", + "EmbeddingOutput": ".outputs:EmbeddingOutput", + "EmbeddingRequestOutput": ".outputs:EmbeddingRequestOutput", + "PoolingOutput": ".outputs:PoolingOutput", + "PoolingRequestOutput": ".outputs:PoolingRequestOutput", + "RequestOutput": ".outputs:RequestOutput", + "ScoringOutput": ".outputs:ScoringOutput", + "ScoringRequestOutput": ".outputs:ScoringRequestOutput", +} + +if typing.TYPE_CHECKING: + from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs + from vllm.engine.async_llm_engine import AsyncLLMEngine + from vllm.engine.llm_engine import LLMEngine + from vllm.entrypoints.llm import LLM + from vllm.executor.ray_utils import initialize_ray_cluster + from vllm.inputs import PromptType, TextPrompt, TokensPrompt + from vllm.model_executor.models import ModelRegistry + from vllm.outputs import (ClassificationOutput, + ClassificationRequestOutput, CompletionOutput, + EmbeddingOutput, EmbeddingRequestOutput, + PoolingOutput, PoolingRequestOutput, + RequestOutput, ScoringOutput, + ScoringRequestOutput) + from vllm.pooling_params import PoolingParams + from vllm.sampling_params import SamplingParams +else: + + def __getattr__(name: str) -> typing.Any: + from importlib import import_module + + if name in MODULE_ATTRS: + module_name, attr_name = MODULE_ATTRS[name].split(":") + module = import_module(module_name, __package__) + return getattr(module, attr_name) + else: + raise AttributeError( + f'module {__package__} has no attribute {name}') + + +__all__ = [ + "__version__", + "__version_tuple__", + "LLM", + "ModelRegistry", + "PromptType", + "TextPrompt", + "TokensPrompt", + "SamplingParams", + "RequestOutput", + "CompletionOutput", + "PoolingOutput", + "PoolingRequestOutput", + "EmbeddingOutput", + "EmbeddingRequestOutput", + "ClassificationOutput", + "ClassificationRequestOutput", + "ScoringOutput", + "ScoringRequestOutput", + "LLMEngine", + "EngineArgs", + "AsyncLLMEngine", + "AsyncEngineArgs", + "initialize_ray_cluster", + "PoolingParams", +] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py new file mode 100644 index 0000000..d8350b6 --- /dev/null +++ b/vllm/_custom_ops.py @@ -0,0 +1,2455 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import contextlib +from typing import TYPE_CHECKING, Optional, Union + +import torch + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType +from vllm.utils import direct_register_custom_op +try: + from lmslim import quant_ops + from lmslim import quant_tools +except Exception: + print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n") +try: + import lightop +except Exception: + print("INFO: Please install lightop if you want to infer awq of marlin.\n") + +logger = init_logger(__name__) + +if not current_platform.is_tpu() and not current_platform.is_hpu(): + try: + import vllm._C + except ImportError as e: + logger.warning("Failed to import from vllm._C with %r", e) + + +supports_moe_ops = False +with contextlib.suppress(ImportError): + import vllm._moe_C # noqa: F401 + supports_moe_ops = True + +if TYPE_CHECKING: + + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + torch.ops._C.paged_attention_v1( + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, + k_scale, v_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + torch.ops._C.paged_attention_v2( + out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, + num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, + alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step) + +def paged_attention_v1_with_mask( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + attn_masks: Optional[torch.Tensor] = None, + attn_masks_stride: int = 0, +) -> None: + torch.ops._C.paged_attention_v1_with_mask( + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, + k_scale, v_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step,attn_masks, + attn_masks_stride) + + +def paged_attention_v2_with_mask( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + query_start_loc: Optional[torch.Tensor], + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + attn_masks: Optional[torch.Tensor] = None, + attn_masks_stride: int = 0, +) -> None: + torch.ops._C.paged_attention_v2_with_mask( + out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, + num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, + alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step, + attn_masks, attn_masks_stride) + + +# page attention ops (opt) +def paged_attention_v1_opt( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0 +) -> None: + torch.ops._C.paged_attention_v1_opt( + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, + k_scale, v_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step) + + +def paged_attention_v2_opt( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0 +) -> None: + torch.ops._C.paged_attention_v2_opt( + out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, + num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, + alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step) + + +def paged_attention_v1_opt_with_mask( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + attn_masks: Optional[torch.Tensor] = None, + attn_masks_stride: int = 0, +) -> None: + torch.ops._C.paged_attention_v1_opt_with_mask( + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, + k_scale, v_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step, attn_masks, + attn_masks_stride) + + +def paged_attention_v2_opt_with_mask( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + attn_masks: Optional[torch.Tensor] = None, + attn_masks_stride: int = 0, +) -> None: + torch.ops._C.paged_attention_v2_opt_with_mask( + out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, + num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, + alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step, + attn_masks, attn_masks_stride) + + +# page attention ops (opt) +def paged_attention_v1_opt_tc( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0 +) -> None: + torch.ops._C.paged_attention_v1_opt_tc( + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, + k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step) + + +def paged_attention_v2_opt_tc( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0 +) -> None: + torch.ops._C.paged_attention_v2_opt_tc( + out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, + num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, + alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step) + + +# page attention ops (opt) +def paged_attention_v1_opt_tc_with_mask( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + attn_masks: Optional[torch.Tensor] = None, + attn_masks_stride: int = 0, +) -> None: + torch.ops._C.paged_attention_v1_opt_tc_with_mask( + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, + k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step, + attn_masks, attn_masks_stride) + + +def paged_attention_v2_opt_tc_with_mask( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + attn_masks: Optional[torch.Tensor] = None, + attn_masks_stride: int = 0, +) -> None: + torch.ops._C.paged_attention_v2_opt_tc_with_mask( + out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, + num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, + alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step, + attn_masks, attn_masks_stride) + + +# def paged_attention_rocm( +# out: torch.Tensor, +# exp_sum: torch.Tensor, +# max_logits: torch.Tensor, +# tmp_out: torch.Tensor, +# query: torch.Tensor, +# key_cache: torch.Tensor, +# value_cache: torch.Tensor, +# num_kv_heads: int, +# scale: float, +# block_tables: torch.Tensor, +# seq_lens: torch.Tensor, +# query_start_loc: Optional[torch.Tensor], +# block_size: int, +# max_seq_len: int, +# alibi_slopes: Optional[torch.Tensor], +# kv_cache_dtype: str, +# k_scale: torch.Tensor, +# v_scale: torch.Tensor, +# fp8_out_scale: Optional[torch.Tensor] = None, +# ) -> None: +# torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, +# key_cache, value_cache, num_kv_heads, +# scale, block_tables, seq_lens, +# query_start_loc, block_size, max_seq_len, +# alibi_slopes, kv_cache_dtype, k_scale, +# v_scale, fp8_out_scale) + + +def mla_decode_kvcache_cpu( + out: torch.Tensor, + query: torch.Tensor, + kv_cache: torch.Tensor, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, +) -> None: + torch.ops._C_cpu.mla_decode_kvcache(out, query, kv_cache, scale, + block_tables, seq_lens) + + +# merge attn states ops +def merge_attn_states(output: torch.Tensor, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + suffix_output: torch.Tensor, + suffix_lse: torch.Tensor, + output_lse: Optional[torch.Tensor] = None) -> None: + torch.ops._C.merge_attn_states(output, output_lse, prefix_output, + prefix_lse, suffix_output, suffix_lse) + + +def convert_vertical_slash_indexes( + q_seqlens: torch.Tensor, # [BATCH, ] + kv_seqlens: torch.Tensor, # [BATCH, ] + vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] + slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] + context_size: int, + block_size_M: int, + block_size_N: int, + causal: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + batch_size = slash_indexes.size(0) + num_heads = slash_indexes.size(1) + nnz_slash = slash_indexes.size(2) + nnz_vertical = vertical_indexes.size(2) + num_rows = (context_size + block_size_M - 1) // block_size_M + + block_count = torch.zeros(batch_size, + num_heads, + num_rows, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + block_offset = torch.zeros(batch_size, + num_heads, + num_rows, + nnz_slash, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + column_count = torch.zeros(batch_size, + num_heads, + num_rows, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + column_index = torch.zeros(batch_size, + num_heads, + num_rows, + nnz_vertical, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + + torch.ops._C.convert_vertical_slash_indexes( + block_count, block_offset, column_count, column_index, q_seqlens, + kv_seqlens, vertical_indexes, slash_indexes, context_size, + block_size_M, block_size_N, causal) + return block_count, block_offset, column_count, column_index + + +def convert_vertical_slash_indexes_mergehead( + q_seqlens: torch.Tensor, # [BATCH, ] + kv_seqlens: torch.Tensor, # [BATCH, ] + vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] + slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] + # [N_HEADS] : different head use different number of indices + vertical_indices_count: torch.Tensor, + slash_indices_count: torch.Tensor, + context_size: int, + block_size_M: int, + block_size_N: int, + causal: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + batch_size = slash_indexes.size(0) + num_heads = slash_indexes.size(1) + nnz_slash = slash_indexes.size(2) + nnz_vertical = vertical_indexes.size(2) + num_rows = (context_size + block_size_M - 1) // block_size_M + + block_count = torch.empty(batch_size, + num_heads, + num_rows, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + block_offset = torch.empty(batch_size, + num_heads, + num_rows, + nnz_slash, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + column_count = torch.empty(batch_size, + num_heads, + num_rows, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + column_index = torch.empty(batch_size, + num_heads, + num_rows, + nnz_vertical, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + + torch.ops._C.convert_vertical_slash_indexes_mergehead( + block_count, block_offset, column_count, column_index, q_seqlens, + kv_seqlens, vertical_indexes, slash_indexes, vertical_indices_count, + slash_indices_count, context_size, block_size_M, block_size_N, causal) + return block_count, block_offset, column_count, column_index + + +# pos encoding ops +def rotary_embedding( + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor], + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool, +) -> None: + torch.ops._C.rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox) + + +def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, + key: Optional[torch.Tensor], head_size: int, + cos_sin_cache: torch.Tensor, is_neox: bool, + rot_dim: int, + cos_sin_cache_offsets: torch.Tensor) -> None: + torch.ops._C.batched_rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox, rot_dim, + cos_sin_cache_offsets) + + +# layer norm ops +def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + epsilon: float) -> None: + # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input + input_contiguous = input.contiguous() + torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon) + + +def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, epsilon: float) -> None: + torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) + + +# layer norm ops (opt) +def rms_norm_opt(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + epsilon: float) -> None: + torch.ops._C.rms_norm_opt(out, input, weight, epsilon) + + +def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, epsilon: float) -> None: + torch.ops._C.fused_add_rms_norm_opt(input, residual, weight, epsilon) + + +def apply_repetition_penalties_torch( + logits: torch.Tensor, prompt_mask: torch.Tensor, + output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None: + repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat( + 1, logits.size(1)) + # If token appears in prompt or output, apply, otherwise use 1.0 for no-op. + penalties = torch.where(prompt_mask | output_mask, repetition_penalties, + 1.0) + # If logits are positive, divide by penalty, otherwise multiply by penalty. + scaling = torch.where(logits > 0, 1.0 / penalties, penalties) + logits *= scaling + + +def apply_repetition_penalties_cuda( + logits: torch.Tensor, prompt_mask: torch.Tensor, + output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None: + torch.ops._C.apply_repetition_penalties_(logits, prompt_mask, output_mask, + repetition_penalties) + + +def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor, + output_mask: torch.Tensor, + repetition_penalties: torch.Tensor) -> None: + """Apply repetition penalties to logits in-place. + + Args: + logits: The logits tensor of shape [num_seqs, vocab_size]. + prompt_mask: A boolean tensor indicating which tokens appear in the prompt. + output_mask: A boolean tensor indicating which tokens appear in the output. + repetition_penalties: The repetition penalties of shape (num_seqs, ). + """ + if current_platform.is_cuda() and logits.is_contiguous(): + apply_repetition_penalties_cuda(logits, prompt_mask, output_mask, + repetition_penalties) + else: + apply_repetition_penalties_torch(logits, prompt_mask, output_mask, + repetition_penalties) + + +def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int, + input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor, + seq_lens: torch.Tensor, slot_mapping: torch.Tensor, + block_tables: torch.Tensor) -> None: + """Advance a step on GPU for existing inputs for a multi-step runner""" + return torch.ops._C.advance_step_flashattn(num_seqs, num_queries, + block_size, input_tokens, + sampled_token_ids, + input_positions, seq_lens, + slot_mapping, block_tables) + + +def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int, + input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor, + seq_lens: torch.Tensor, slot_mapping: torch.Tensor, + block_tables: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + block_table_bound: torch.Tensor) -> None: + + return torch.ops._C.advance_step_flashinfer( + num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, + input_positions, seq_lens, slot_mapping, block_tables, + paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len, + block_table_bound) + +# trans_w16 +def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor, + row:int, col:int) -> None : + torch.ops._C.trans_w16_gemm(dst,src,row,col) + + +# fused quant layer norm ops +def rms_norm_dynamic_per_token_quant( + input: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + quant_dtype: torch.dtype, + scale_ub: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: + output = torch.empty_like(input, dtype=quant_dtype) + scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + + torch.ops._C.rms_norm_dynamic_per_token_quant(output, input, weight, + scales, epsilon, scale_ub, + residual) + return output, scales + + +# quantization ops +# awq +def GetAWQShareWorkspaceSize()->int: + return quant_ops.GetAWQShareWorkspaceSize() + +def GetAWQShareWorkspace()->torch.Tensor: + return quant_ops.GetAWQShareWorkspace() + +def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, + zeros: torch.Tensor, split_k_iters: int, thx: int, + thy: int) -> torch.Tensor: + if envs.VLLM_USE_TRITON_AWQ: + from vllm.model_executor.layers.quantization.awq_triton import ( + awq_dequantize_triton) + return awq_dequantize_triton(qweight, scales, zeros) + return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, + thx, thy) + + +# def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, +# scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: +# if envs.VLLM_USE_TRITON_AWQ: +# from vllm.model_executor.layers.quantization.awq_triton import ( +# awq_gemm_triton) +# return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters) +# return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters) + + +def awq_gemm(input: torch.Tensor, weight: torch.Tensor, + zeros_and_scales:torch.Tensor, + m:int,n:int,k:int, + group_size:int,padding_group:int,splikspace:torch.Tensor, + splikspacesize:int) -> torch.Tensor: + return quant_ops.awq_gemm(input, + weight, + zeros_and_scales, + m, + n, + k, + group_size, + padding_group, + splikspace, + splikspacesize) + +def awq_gemm_fake(input: torch.Tensor, weight: torch.Tensor, + zeros_and_scales:torch.Tensor, + m:int,n:int,k:int, + group_size:int,padding_group:int,splikspace:torch.Tensor, + splikspacesize:int) -> torch.Tensor: + + return torch.empty((m, n), dtype=input.dtype, device=input.device) + +def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor, + group_size: int): + return quant_ops.convert_s4(qw,qz,s,group_size) + +def sz_permute(sz:torch.Tensor)-> torch.Tensor: + return quant_ops.sz_permute(sz) + +def dequant_w4_gemm_colmajor(qweight:torch.Tensor, + zeros_and_scale:torch.Tensor, + k:int, + n:int, + group_size:int + )->torch.Tensor: + return quant_ops.dequant_w4_gemm_colmajor(qweight,zeros_and_scale,k,n,group_size) + + +# gptq +def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, + b_g_idx: torch.Tensor, use_exllama: bool, + bit: int) -> torch.Tensor: + return quant_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, + b_g_idx, use_exllama, bit) + # return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, + # b_g_idx, use_exllama, bit) + + +if hasattr(torch.ops._C, "gptq_gemm"): + + @register_fake("_C::gptq_gemm") + def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, + b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, + use_exllama: bool, bit: int) -> torch.Tensor: + return torch.empty((a.size(0), b_q_weight.size(1)), + dtype=a.dtype, + device=a.device) + + +def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, + bit: int) -> None: + quant_ops.gptq_shuffle(q_weight, q_perm, bit) + # torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) + + + +# marlin +# def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, +# b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, +# size_n: int, size_k: int) -> torch.Tensor: +# return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, +# size_n, size_k) + + +# # marlin_24 +# def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, +# b_meta: torch.Tensor, b_scales: torch.Tensor, +# workspace: torch.Tensor, b_q_type: ScalarType, +# size_m: int, size_n: int, size_k: int) -> torch.Tensor: +# return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, +# workspace, b_q_type.id, size_m, +# size_n, size_k) + + +# if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): + +# @register_fake("_C::gptq_marlin_24_gemm") +# def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, +# b_meta: torch.Tensor, b_scales: torch.Tensor, +# workspace: torch.Tensor, +# b_q_type: ScalarType, size_m: torch.SymInt, +# size_n: torch.SymInt, +# size_k: torch.SymInt) -> torch.Tensor: +# return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + +# @register_fake("_C::gptq_marlin_gemm") +# def _gptq_marlin_gemm_fake(a: torch.Tensor, +# c: Optional[torch.Tensor], +# b_q_weight: torch.Tensor, +# b_scales: torch.Tensor, +# global_scale: Optional[torch.Tensor], +# b_zeros: Optional[torch.Tensor], +# g_idx: Optional[torch.Tensor], +# perm: Optional[torch.Tensor], +# workspace: torch.Tensor, +# b_q_type_id: int, +# size_m: torch.SymInt, +# size_n: torch.SymInt, +# size_k: torch.SymInt, +# is_k_full: bool = True, +# use_atomic_add: bool = False, +# use_fp32_reduce: bool = False, +# is_zp_float: bool = False) -> torch.Tensor: +# return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + +# @register_fake("_C::marlin_qqq_gemm") +# def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, +# s_tok: torch.Tensor, s_ch: torch.Tensor, +# s_group: torch.Tensor, workspace: torch.Tensor, +# size_m: torch.SymInt, size_n: torch.SymInt, +# size_k: torch.SymInt) -> torch.Tensor: +# return torch.empty((size_m, size_n), +# dtype=torch.float16, +# device=a.device) + +# @register_fake("_C::marlin_gemm") +# def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, +# b_scales: torch.Tensor, workspace: torch.Tensor, +# size_m: torch.SymInt, size_n: torch.SymInt, +# size_k: torch.SymInt) -> torch.Tensor: +# return torch.empty((size_m, size_n), +# dtype=torch.float16, +# device=a.device) + +# @register_fake("_C::awq_dequantize") +# def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, +# zeros: torch.Tensor, split_k_iters: torch.SymInt, +# thx: int, thy: int) -> torch.Tensor: +# in_c = qweight.size(0) +# qout_c = qweight.size(1) +# out_c = qout_c * 8 +# return torch.empty((in_c, out_c), +# dtype=scales.dtype, +# device=scales.device) + +# @register_fake("_C::awq_gemm") +# def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor, +# qzeros: torch.Tensor, scales: torch.Tensor, +# split_k_iters: torch.SymInt) -> torch.Tensor: +# num_in_feats = input.size(0) +# return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8), +# dtype=input.dtype, +# device=input.device).sum(0) + +# @register_fake("_C::aqlm_gemm") +# def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor, +# codebooks: torch.Tensor, scales: torch.Tensor, +# codebook_partition_sizes: list[int], +# bias: Optional[torch.Tensor]) -> torch.Tensor: +# out_features = codes.size(0) * codebooks.size(2) +# flat_input = input.reshape((-1, input.size(-1))) +# flat_output = torch.empty((flat_input.size(0), out_features), +# dtype=input.dtype, +# device=input.device) + +# output_sizes = list(input.shape) +# output_sizes.pop() +# output_sizes.append(-1) +# return flat_output.reshape(tuple(output_sizes)) + +# @register_fake("_C::aqlm_dequant") +# def _aqlm_dequant_fake( +# codes: torch.Tensor, codebooks: torch.Tensor, +# codebook_partition_sizes: list[int]) -> torch.Tensor: +# in_features = codes.size(1) * 8 +# out_features = codes.size(0) +# return torch.empty((out_features, in_features), +# dtype=codebooks.dtype, +# device=codebooks.device) + +# @register_fake("_C::machete_mm") +# def machete_mm_fake( +# a: torch.Tensor, +# # b_q Should be the tensor returned by machete_prepack_B +# b_q: torch.Tensor, +# b_type: ScalarType, +# out_type: Optional[torch.dtype] = None, +# b_group_scales: Optional[torch.Tensor] = None, +# b_group_zeros: Optional[torch.Tensor] = None, +# b_group_size: Optional[int] = None, +# b_channel_scales: Optional[torch.Tensor] = None, +# a_token_scales: Optional[torch.Tensor] = None, +# schedule: Optional[str] = None, +# ) -> torch.Tensor: +# m = a.size(0) +# n = b_q.size(1) +# return torch.empty((m, n), device=a.device, dtype=a.dtype) + +# @register_fake("_C::machete_prepack_B") +# def machete_prepack_B_fake( +# b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, +# group_scales_type: Optional[torch.dtype]) -> torch.Tensor: +# return torch.empty_like(b_q_weight, +# memory_format=torch.contiguous_format) + + +# if hasattr(torch.ops._C, "allspark_w8a16_gemm"): + +# @register_fake("_C::allspark_w8a16_gemm") +# def _allspark_w8a16_gemm_fake(a: torch.Tensor, b_qweight: torch.Tensor, +# b_scales: torch.Tensor, +# b_qzeros: Optional[torch.Tensor], +# n: torch.SymInt, group_size: torch.SymInt, +# sm_count: torch.SymInt, +# sm_version: torch.SymInt, +# CUBLAS_M_THRESHOLD: torch.SymInt, +# has_zp: bool, +# n32k16_reorder: bool) -> torch.Tensor: +# m = a.size(0) +# return torch.empty((m, n), device=a.device, dtype=a.dtype) + + +if hasattr(torch.ops._C, "ggml_dequantize"): + + @register_fake("_C::ggml_dequantize") + def _ggml_dequantize_fake( + W: torch.Tensor, + quant_type: int, + m: torch.SymInt, + n: torch.SymInt, + dtype: Optional[torch.dtype] = None) -> torch.Tensor: + return torch.empty((m, n), dtype=torch.float16, device=W.device) + + @register_fake("_C::ggml_mul_mat_vec_a8") + def _ggml_mul_mat_vec_a8_fake( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: torch.SymInt, + ) -> torch.Tensor: + return torch.empty((X.shape[0], row), dtype=X.dtype, device=W.device) + + @register_fake("_C::ggml_mul_mat_a8") + def _ggml_mul_mat_a8_fake( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: torch.SymInt, + ) -> torch.Tensor: + batch = X.size(0) + return torch.empty((batch, row), dtype=X.dtype, device=W.device) + + @register_fake("_C::ggml_moe_a8") + def _ggml_moe_a8_fake( + X: torch.Tensor, + W: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + quant_type: int, + row: torch.SymInt, + top_k: torch.SymInt, + tokens: torch.SymInt, + ) -> torch.Tensor: + tokens = X.size(0) + return torch.empty((tokens * top_k, row), + dtype=torch.float16, + device=W.device) + + +if hasattr(torch.ops._C, "ggml_moe_a8_vec"): + + @register_fake("_C::ggml_moe_a8_vec") + def _ggml_moe_a8_vec_fake( + X: torch.Tensor, + W: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + quant_type: int, + row: torch.SymInt, + tokens: torch.SymInt, + ) -> torch.Tensor: + tokens = X.size(0) + return torch.empty((tokens * top_k, row), + dtype=X.dtype, + device=W.device) + + +# cutlass +def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool: + return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability) + + +def cutlass_blockwise_scaled_grouped_mm( + output: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + scales_a: torch.Tensor, + scales_b: torch.Tensor, + problem_sizes: torch.Tensor, + expert_offsets: torch.Tensor, +): + torch.ops._C.cutlass_blockwise_scaled_grouped_mm(output, a, b, scales_a, + scales_b, problem_sizes, + expert_offsets) + + +def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, + block_scale_a: torch.Tensor, + block_scale_b: torch.Tensor, alpha: torch.Tensor, + out_dtype: torch.dtype) -> torch.Tensor: + assert a.ndim == 2 and b.ndim == 2 + m, n = a.shape[0], b.shape[0] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + torch.ops._C.cutlass_scaled_fp4_mm(out, a, b, block_scale_a, block_scale_b, + alpha) + return out + + +def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: + return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability) + + +def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool: + return torch.ops._C.cutlass_scaled_mm_supports_block_fp8( + cuda_device_capability) + + +def cutlass_scaled_mm(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + `cutlass_scaled_mm` implements a fused version of + `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)` + where scale_a * a and scale_b * b are implemented using numpy-style + broadcasting. + + In order to support blockwise scaling like found in DeepSeek V3 we also + support extended "group" broadcast rules. We extend the numpy-style + broadcasting rules with the following rule: + "if the extent of a dimension in the source shape is between 1 and + corresponding extent in the target shape we repeat each element along + that dimension src_shape[dim] // target_shape[dim] times consecutively" + example if we have: + a = [[1, 2], and target_shape = (2, 4) + [3, 4]] + then we would expand a to: + a = [[1, 1, 2, 2], + [3, 3, 4, 4]] + currently we only support the case: + scale_a.shape * [1, 128] == a.shape + scale_b.shape * [128, 128] == b.shape + """ + assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) + assert bias is None or bias.shape[0] == b.shape[ + 1] and bias.dtype == out_dtype + + # m = a.shape[0] + # n = b.shape[1] + + # cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + # if current_platform.is_rocm() or not cutlass_compatible_b: + # from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa + # triton_scaled_mm) + # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + + # out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + # torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + + # return out + #return quant_ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias) + +def rocblas_scaled_mm(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias) + +def triton_scaled_mm(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, + best_config:Optional[list] = None) -> torch.Tensor: + + return quant_ops.triton_scaled_mm(a, b,scale_a,scale_b,out_dtype,bias,best_config) + +def triton_int8_gemm_helper(m: int, + n: int, + k: int, + per_token_act_quant: bool, + per_out_channel_weight_quant: bool, + use_bias: bool, + out_dtype: type[torch.dtype] = torch.float16, + device: str = "cuda:0", + best_config:Optional[list] = None, + repeat:Optional[int] = 2): + return quant_tools.triton_int8_gemm_helper(m,n,k,per_token_act_quant,per_out_channel_weight_quant,use_bias,out_dtype,device,best_config,repeat) + +def triton_blockint8_gemm_helper(m: int, + n: int, + k: int, + block_size:list=[128,128], + use_bias: bool=False, + out_dtype: type[torch.dtype] = torch.bfloat16, + device: str = "cuda:0", + best_config:Optional[dict] = None, + repeat:Optional[int] = 2): + + return quant_tools.triton_blockint8_gemm_helper(m,n,k,block_size,use_bias,out_dtype,device,best_config,repeat) + + +def cutlass_scaled_mm_azp(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + azp_adj: torch.Tensor, + azp: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + :param azp_adj: In the per-tensor case, this should include the azp. + Always per-channel. + :param azp: Only set in the per-token case. Per-token if set. + """ + assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) + assert bias is None or bias.numel( + ) == b.shape[1] and bias.dtype == out_dtype + assert azp is None or azp.numel() == a.shape[0] + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, + azp, bias) + return out + + +def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool: + return torch.ops._C.cutlass_sparse_scaled_mm_supported( + cuda_device_capability) + + +def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool: + return torch.ops._C.cutlass_group_gemm_supported(cuda_device_capability) + +def cutlass_sparse_compress(a: torch.Tensor) \ + -> tuple[torch.Tensor, torch.Tensor]: + """ + Compresses a sparse matrix for use with Cutlass sparse operations. + + This function takes a dense tensor and compresses it into two components: + non-zero elements and metadata. The compressed representation is compatible + with Cutlass sparse kernels. + + Args: + a (torch.Tensor): + The input tensor to be compressed. Must have one of the following data types: + - `torch.int8` + - `torch.float8_e4m3fn` + - `torch.bfloat16` + - `torch.float16` + + Returns: + tuple[torch.Tensor, torch.Tensor]: + A tuple containing: + - `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`. + - `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation. + + Raises: + ValueError: If the compression operation fails. + + Notes: + - The `a_meta` tensor has a data type of `torch.uint8`. + - Each metadata element encodes the sparsity of 4 non-zero elements (i.e., `elemsPerMetaElem = 4`). + - The shape of `a_nzs` is `(m, k // 2)`, where `m` and `k` are the dimensions of the input tensor. + - The shape of `a_meta` is `(m, k // 2 // elemsPerMetaElem)`. + """ + assert (a.dtype in [ + torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16 + ]) + assert (a.is_contiguous()) + + # a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4 + elemsPerMetaElem = 4 + assert (a.shape[1] % (2 * elemsPerMetaElem) == 0) + + return torch.ops._C.cutlass_sparse_compress(a) + + +def cutlass_scaled_sparse_mm( + a: torch.Tensor, + bt_nzs: torch.Tensor, + bt_meta: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Performs a scaled sparse matrix multiplication using Cutlass. + + Steps: + 1. Create a dense matrix `a` of shape (m, k) on the CUDA device: + `a = torch.randn((m, k), device='cuda')`. + + 2. Create a dense matrix `b` of shape (k, n) on the CUDA device: + `b = torch.randn((k, n), device='cuda')`. + + 3. Prune matrix `b` to 2:4 sparsity along the specified dimension: + `b = prune_to_2_4(b, dim=0)`. + + 4. Compress the transposed sparse matrix `b.t()`: + `bt_nzs, bt_meta = cutlass_sparse_compress(b.t())`. + + 5. Perform sparse matrix multiplication using the compressed matrix, + applying scaling factors for `a` and `b`, and the output data type: + `out = cutlass_scaled_sparse_mm(a, bt_nzs, bt_meta, scale_a, scale_b, out_dtype)`. + + Returns: + - The result of the scaled sparse matrix multiplication. + """ + assert (bt_nzs.shape[0] % 16 == 0 and bt_nzs.shape[1] % 16 == 0) + assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) + assert bias is None or bias.shape[0] == bt_nzs.shape[0] \ + and bias.dtype == out_dtype + + m = a.shape[0] + n = bt_nzs.shape[0] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + torch.ops._C.cutlass_scaled_sparse_mm(out, a, bt_nzs, bt_meta, scale_a, + scale_b, bias) + + return out + + +def get_cutlass_moe_mm_data(topk_ids: torch.Tensor, + expert_offsets: torch.Tensor, + problem_sizes1: torch.Tensor, + problem_sizes2: torch.Tensor, + input_permutation: torch.Tensor, + output_permutation: torch.Tensor, + num_experts: int, + n: int, + k: int, + blockscale_offsets: Optional[torch.Tensor] = None): + """ + Prepare data necessary to perform CUTLASS grouped matrix multiplications + used in CUTLASS-based fused MoE. + + The function takes in topk_ids (token-expert mapping) and uses it to + compute: + - expert_offsets: Indices that mark at which token index each expert begins + its computation after the input is sorted with + input_permutation. The number of tokens computed with + expert E is expert_offsets[E + 1] - expert_offsets[E] + - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's + multiplication in two grouped MMs used in + the fused MoE operation. + - input_permutation: Permutation that must be used to shuffle the input + before executing the MMs. + - output_permutation: Permutation that must be used to shuffle the output + after executing the MMs. + - blockscale_offsets: Optional argument passed for fp4 moe. Indices that + mark at which block scale index each expert begins + its computation. The number of block scale rows + computed with expert E is blockscale_offsets[E + 1] - + blockscale_offsets[E] + """ + return torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets, + problem_sizes1, problem_sizes2, + input_permutation, + output_permutation, + num_experts, n, k, + blockscale_offsets) + + +def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor): + """ + Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor. + This is used in MoE to permute the input tensor before performing grouped matrix multiplications. + """ + num_tokens_permuted = dst2src_map.shape[0] + output_tensor = torch.empty((num_tokens_permuted, input_tensor.shape[1]), + device=input_tensor.device, + dtype=input_tensor.dtype) + torch.ops._moe_C.shuffle_rows(input_tensor, dst2src_map, output_tensor) + return output_tensor + + +def get_cutlass_pplx_moe_mm_data(expert_offsets: torch.Tensor, + problem_sizes1: torch.Tensor, + problem_sizes2: torch.Tensor, + expert_num_tokens: torch.Tensor, + num_local_experts: int, padded_m: int, n: int, + k: int): + """ + Prepare data necessary to perform CUTLASS grouped matrix multiplications + used in CUTLASS-based fused MoE. + + The function takes in expert_num_tokens (token count per expert) and + non_zero_expert_idxs (consecutive indices of experts with non-zero token + counts) and uses them to compute: + - expert_offsets: Indices that mark at which token index each expert begins + its computation. + - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's + multiplication in two grouped MMs used in + the fused MoE operation. + """ + return torch.ops._C.get_cutlass_pplx_moe_mm_data( + expert_offsets, problem_sizes1, problem_sizes2, expert_num_tokens, + num_local_experts, padded_m, n, k) + + +def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, + b_tensors: torch.Tensor, a_scales: torch.Tensor, + b_scales: torch.Tensor, expert_offsets: torch.Tensor, + problem_sizes: torch.Tensor, a_strides: torch.Tensor, + b_strides: torch.Tensor, c_strides: torch.Tensor, + per_act_token: bool, per_out_ch: bool): + """ + A single grouped matrix multiplication used in CUTLASS-based fused MoE. + The function executes fp8-quantized OUT = AB matrix multiplication. + + - expert_offsets: Indices that mark at which token index each expert begins + its computation. The number of tokens computed with + expert E is expert_offsets[E + 1] - expert_offsets[E] + - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped + MMs used in the fused MoE operation. + - a/b/c_strides: The data strides passed to grouped matrix multiplication. + """ + return torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, + a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, + c_strides, per_act_token, per_out_ch) + + +def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor, + a_scales: torch.Tensor, b_scales: torch.Tensor, + alphas: torch.Tensor, problem_sizes: torch.Tensor, + expert_offsets: torch.Tensor, sf_offsets: torch.Tensor, + out_dtype: torch.dtype, device: torch.device): + """ + An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs + the gemms for each combination based on the specified problem sizes. + + This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward. + - a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized + input and expert weights. + - a_/b_scales: The blockscales in FP8-E4M3 precision + - expert_offsets/sf_offsets: Indices that mark at which token index + each expert begins its computation. The number of tokens + computed with expert E is expert_offsets[E + 1] - + expert_offsets[E] And the sf_size per expert is + sf_offset[E+1] - sf_offset[E] + - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped + MMs used in the fused MoE operation. + """ + m_topk = a_tensors.shape[0] + n = b_tensors.shape[1] + c_shape = (m_topk, n) + c = torch.empty(c_shape, device=device, dtype=out_dtype) + torch.ops._C.cutlass_fp4_group_mm(c, a_tensors, b_tensors, a_scales, + b_scales, alphas, problem_sizes, + expert_offsets, sf_offsets) + return c.to(out_dtype) + + +# aqlm +def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, + codebooks: torch.Tensor, scales: torch.Tensor, + codebook_partition_sizes: list[int], + bias: Optional[torch.Tensor]) -> torch.Tensor: + return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales, + codebook_partition_sizes, bias) + + +def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, + codebook_partition_sizes: list[int]) -> torch.Tensor: + return torch.ops._C.aqlm_dequant(codes, codebooks, + codebook_partition_sizes) + + +# gptq_marlin +def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, + num_bits) + + +# gptq_marlin +def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) + + +def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + num_experts = b_q_weight.shape[0] + assert size_k % 16 == 0 + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), + device=b_q_weight.device, + dtype=b_q_weight.dtype) + for e in range(num_experts): + output[e] = torch.ops._C.gptq_marlin_repack(b_q_weight[e], perm[e], + size_k, size_n, num_bits) + return output + + +def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + num_experts = b_q_weight.shape[0] + assert size_k % 16 == 0 + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), + device=b_q_weight.device, + dtype=b_q_weight.dtype) + for e in range(num_experts): + output[e] = lightop.awq_marlin_repack(b_q_weight[e], size_k, + size_n, num_bits) + return output + + +def gptq_marlin_gemm(a: torch.Tensor, + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_scales, + global_scale, b_zeros, g_idx, perm, + workspace, b_q_type.id, size_m, + size_n, size_k, is_k_full, + use_atomic_add, use_fp32_reduce, + is_zp_float) + + +# machete +def machete_supported_schedules( + a_type: torch.dtype, + b_type: ScalarType, + group_scales_type: Optional[torch.dtype], + group_zeros_type: Optional[torch.dtype] = None, + channel_scales_type: Optional[torch.dtype] = None, + token_scales_type: Optional[torch.dtype] = None, + out_type: Optional[torch.dtype] = None) -> list[str]: + return torch.ops._C.machete_supported_schedules( + a_type, b_type.id, group_scales_type, group_zeros_type, + channel_scales_type, token_scales_type, out_type) + + +def machete_mm( + a: torch.Tensor, + # b_q Should be the tensor returned by machete_prepack_B + b_q: torch.Tensor, + b_type: ScalarType, + out_type: Optional[torch.dtype] = None, + b_group_scales: Optional[torch.Tensor] = None, + b_group_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + b_channel_scales: Optional[torch.Tensor] = None, + a_token_scales: Optional[torch.Tensor] = None, + schedule: Optional[str] = None) -> torch.Tensor: + return torch.ops._C.machete_mm(a, b_q, b_type.id, out_type, b_group_scales, + b_group_zeros, b_group_size, + b_channel_scales, a_token_scales, schedule) + + +def machete_prepack_B( + b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, + group_scales_type: Optional[torch.dtype]) -> torch.Tensor: + return torch.ops._C.machete_prepack_B(b_q_weight, a_type, b_type.id, + group_scales_type) + + +if hasattr(torch.ops._C, "permute_cols"): + + @register_fake("_C::permute_cols") + def _permute_cols_fake(a: torch.Tensor, + perm: torch.Tensor) -> torch.Tensor: + return torch.empty_like(a) + + +def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: + return torch.ops._C.permute_cols(a, perm) + + +# fp4 +def scaled_fp4_quant( + input: torch.Tensor, + input_global_scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP4 and return quantized tensor and scale. + + This function quantizes the last dimension of the given tensor `input`. For + every 16 consecutive elements, a single dynamically computed scaling factor + is shared. This scaling factor is quantized using the `input_global_scale` + and is stored in a swizzled layout (see + https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x). + + Args: + input: The input tensor to be quantized to FP4 + input_global_scale: A scalar scaling factor for the entire tensor. + + Returns: + tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every + two values are packed into a uint8 and float8_e4m3 scaling factors + in the sizzled layout. + """ + assert not current_platform.is_rocm() + assert input.ndim >= 1, ( + f'input.ndim needs to be >= 1, but got {input.ndim}.') + other_dims = 1 if input.ndim == 1 else -1 + input = input.reshape(other_dims, input.shape[-1]) + m, n = input.shape + block_size = 16 + device = input.device + + assert n % block_size == 0, ( + f'last dim has to be multiple of 16, but got {n}.') + assert input.dtype in (torch.float16, torch.bfloat16), ( + f'input.dtype needs to be fp16 or bf16 but got {input.dtype}.') + + # Two fp4 values will be packed into an uint8. + output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) + + # We use the rounded values to store the swizzled values. Due to the + # requirement of the Tensor Core, the minimum tile is 128x4 for the scales. + # So, we first pad the scales to multiples of 128 and 4. Then, the scales + # (in float8_e4m3fn) are packed into an int32 for every 4 values. More: + # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x + round_up = lambda x, y: (x + y - 1) // y * y + rounded_m = round_up(m, 128) + scale_n = n // block_size + rounded_n = round_up(scale_n, 4) + output_scale = torch.empty((rounded_m, rounded_n // 4), + device=device, + dtype=torch.int32) + + torch.ops._C.scaled_fp4_quant(output, input, output_scale, + input_global_scale) + output_scale = output_scale.view(torch.float8_e4m3fn) + return output, output_scale + + +def scaled_fp4_experts_quant( + input_tensor: torch.Tensor, + input_global_scale: torch.Tensor, + expert_offsets: torch.Tensor, + blockscale_offsets: torch.Tensor, + topk: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP4 and return quantized tensor and scale, for + packed MoE Inputs. + Args: + input_tensor: The input tensor to be quantized to FP4 + input_global_scale: A scalar scaling factor for the entire tensor. + expert_offsets: The expert offsets tensor + blockscale_offsets: The blockscale offsets tensor + Outputs: + output: The quantized tensor in FP4 + output_scales: The blockscale tensor in FP8-E4M3 + """ + assert not current_platform.is_rocm() + assert input_tensor.ndim == 2, ( + f'input.ndim needs to be == 2, but got {input_tensor.ndim}.') + + # Control the maximum number of tokens per expert supported by the + # NVFP4 MoE Expert Quantization. This is used to prevent the kernel + # from running out of memory. This value can also be increased to support + # larger models. + MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE + m_numtopk, k = input_tensor.shape + + assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), ( + f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT(" + f"{MAX_TOKENS_PER_EXPERT})" + f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use" + f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value.") + scales_k = k // 16 + padded_k = (scales_k + (4 - 1)) // 4 + + # output is uint8 and packed fp4 values + output = torch.empty(m_numtopk, + k // 2, + device=input_tensor.device, + dtype=torch.uint8) + output_scales = torch.empty(MAX_TOKENS_PER_EXPERT * topk, + padded_k, + dtype=torch.int32, + device=input_tensor.device) + torch.ops._C.scaled_fp4_experts_quant(output, output_scales, input_tensor, + input_global_scale, expert_offsets, + blockscale_offsets) + output_scales = output_scales.view(torch.float8_e4m3fn) + return output, output_scales + + +# fp8 +# def scaled_fp8_quant( +# input: torch.Tensor, +# scale: Optional[torch.Tensor] = None, +# num_token_padding: Optional[int] = None, +# scale_ub: Optional[torch.Tensor] = None, +# use_per_token_if_dynamic: bool = False, +# output: Optional[torch.Tensor] = None, +# ) -> tuple[torch.Tensor, torch.Tensor]: +# """ +# Quantize input tensor to FP8 and return quantized tensor and scale. + +# This function supports both static and dynamic quantization: If you +# provide the scale, it will use static scaling and if you omit it, +# the scale will be determined dynamically. The function also allows +# optional padding of the output tensors for downstream kernels that +# will benefit from padding. + +# Args: +# input: The input tensor to be quantized to FP8 +# scale: Optional scaling factor for the FP8 quantization +# scale_ub: Optional upper bound for scaling factor in dynamic +# per token case +# num_token_padding: If specified, pad the first dimension +# of the output to at least this value. +# use_per_token_if_dynamic: Whether to do per_tensor or per_token +# in the dynamic quantization case. + +# Returns: +# tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and +# scaling factor. +# """ +# # This code assumes batch_dim and num_tokens are flattened +# assert (input.ndim == 2) +# shape: Union[tuple[int, int], torch.Size] = input.shape +# # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz +# out_dtype: torch.dtype = current_platform.fp8_dtype() +# if num_token_padding: +# shape = (max(num_token_padding, input.shape[0]), shape[1]) +# if output is None: +# output = torch.empty(shape, device=input.device, dtype=out_dtype) +# else: +# assert num_token_padding is None, \ +# "padding not supported if output passed in" +# assert output.dtype == out_dtype + +# if scale is None: +# if use_per_token_if_dynamic: +# scale = torch.empty((shape[0], 1), +# device=input.device, +# dtype=torch.float32) +# torch.ops._C.dynamic_per_token_scaled_fp8_quant( +# output, input.contiguous(), scale, scale_ub) +# else: +# scale = torch.zeros(1, device=input.device, dtype=torch.float32) +# torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) +# else: +# assert scale.numel() == 1, f"{scale.shape}" +# torch.ops._C.static_scaled_fp8_quant(output, input, scale) + +# return output, scale + + +# gptq allspark +def allspark_repack_weight( + qweight: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor] = None, + has_zp: bool = False +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format + for Ampere W8A16 Fused Gemm kernel + + Args: + qweight: uint8 weight tensor, original k x n format. + scale: fp16/bf16 weight scale tensor, 1 x n format. + zero_point: fp16/bf16 weight zero_point tensor, 1 x n format. + Must be provided for asymmetric quantization. + has_zp: if use symmetric quantization, has_zp = False. + if use asymmetric quantization, has_zp = True. + + Returns: + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : + rearranged weight, scale, and optionally zero_point. + """ + K = qweight.shape[0] + N = qweight.shape[1] + N_32align = (N + 32 - 1) // 32 * 32 + + qweight_reorder = torch.empty((N_32align, K), + device=qweight.device, + dtype=qweight.dtype) + scale_reorder = torch.empty((1, N_32align), + device=scale.device, + dtype=scale.dtype) + zero_point_reorder = None + if has_zp: + assert zero_point is not None, ( + "zero_point must be provided for asymmetric quantization.") + zero_point_reorder = torch.empty((1, N_32align), + device=zero_point.device, + dtype=zero_point.dtype) + + torch.ops._C.rearrange_kn_weight_as_n32k16_order( + qweight, scale, zero_point, has_zp, qweight_reorder, scale_reorder, + zero_point_reorder, K, N, N_32align) + + return qweight_reorder, scale_reorder, zero_point_reorder + + +def allspark_w8a16_gemm(a: torch.Tensor, b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], n: int, + group_size: int, sm_count: int, sm_version: int, + CUBLAS_M_THRESHOLD: int, has_zp: bool, + n32k16_reorder: bool) -> torch.Tensor: + + return torch.ops._C.allspark_w8a16_gemm(a, b_qweight, b_scales, b_qzeros, + n, group_size, sm_count, + sm_version, CUBLAS_M_THRESHOLD, + has_zp, n32k16_reorder) + + +# int8 +def scaled_int8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp + is None), "azp must only be provided for asymmetric quantization." + torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, azp + + # dynamic-per-token quantization. + input_scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + torch.ops._C.dynamic_scaled_int8_quant(output, input.contiguous(), + input_scales, input_azp) + return output, input_scales, input_azp + + +# qqq ops +def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: int, size_n: int, size_k: int) -> torch.Tensor: + return torch.ops._C.marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group, + workspace, size_m, size_n, size_k) + + +# gguf +def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int, + dtype: Optional[torch.dtype]) -> torch.Tensor: + return torch.ops._C.ggml_dequantize(W, quant_type, m, n, dtype) + + +def ggml_mul_mat_vec_a8( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: int, +) -> torch.Tensor: + return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row) + + +def ggml_mul_mat_a8( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: int, +) -> torch.Tensor: + return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row) + + +def ggml_moe_a8( + X: torch.Tensor, + W: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + quant_type: int, + row: int, + top_k: int, + tokens: int, +) -> torch.Tensor: + return torch.ops._C.ggml_moe_a8(X, W, sorted_token_ids, expert_ids, + num_tokens_post_padded, quant_type, row, + top_k, tokens) + + +def ggml_moe_a8_vec( + X: torch.Tensor, + W: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + quant_type: int, + row: torch.SymInt, + tokens: torch.SymInt, +) -> torch.Tensor: + return torch.ops._C.ggml_moe_a8_vec(X, W, topk_ids, top_k, quant_type, row, + tokens) + + +def ggml_moe_get_block_size(quant_type: int) -> int: + return torch.ops._C.ggml_moe_get_block_size(quant_type) + + +# mamba +def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, + bias_: Optional[torch.Tensor], + conv_states: Optional[torch.Tensor], + query_start_loc: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], + silu_activation: bool, pad_slot_id: int): + torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states, + query_start_loc, cache_indices, + has_initial_state, silu_activation, + pad_slot_id) + + +def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, + weight: torch.Tensor, bias_: Optional[torch.Tensor], + silu_activation: bool, + cache_seqlens: Optional[torch.Tensor], + conv_state_indices: Optional[torch.Tensor], + pad_slot_id: int): + torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, + silu_activation, cache_seqlens, + conv_state_indices, pad_slot_id) + + +def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, + B: torch.Tensor, C: torch.Tensor, + D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], + delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, + query_start_loc: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], + ssm_states: torch.Tensor, pad_slot_id: int): + torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_, + delta_softplus, query_start_loc, + cache_indices, has_initial_state, + ssm_states, pad_slot_id) + + +# ROCm skinny gemms +def LLMM1(a: torch.Tensor, b: torch.Tensor, + rows_per_block: int) -> torch.Tensor: + return torch.ops._rocm_C.LLMM1(a, b, rows_per_block) + + +def wvSplitK(a: torch.Tensor, b: torch.Tensor, cu_count: int) -> torch.Tensor: + return torch.ops._rocm_C.wvSplitK(a, b, cu_count) + + +def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype, + scale_a: torch.Tensor, scale_b: torch.Tensor, + cu_count: int) -> torch.Tensor: + out = torch.empty((b.shape[0], a.shape[0]), + dtype=out_dtype, + device=b.device) + torch.ops._rocm_C.wvSplitKQ(a, b, out, scale_a, scale_b, cu_count) + return out + + +# moe +def moe_sum(input: torch.Tensor, output: torch.Tensor): + torch.ops._moe_C.moe_sum(input, output) +def moe_sum_opt1(input: torch.Tensor, output: torch.Tensor): + torch.ops._moe_C.moe_sum_opt1(input, output) + +def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, + block_size: int, sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor) -> None: + torch.ops._moe_C.moe_align_block_size(topk_ids, num_experts, block_size, + sorted_token_ids, experts_ids, + num_tokens_post_pad) + + +def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor, + b_qweight: torch.Tensor, b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + topk_weights: Optional[torch.Tensor], + sorted_token_ids: torch.Tensor, experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, top_k: int, + BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int, + bit: int) -> torch.Tensor: + if not current_platform.is_cuda(): + raise NotImplementedError( + "The optimized moe_wna16_gemm kernel is only " + "available on CUDA platforms") + torch.ops._moe_C.moe_wna16_gemm(input, output, b_qweight, b_scales, + b_qzeros, topk_weights, sorted_token_ids, + experts_ids, num_tokens_post_pad, top_k, + BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, + bit) + + +def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor) -> None: + torch.ops._moe_C.topk_softmax(topk_weights, topk_ids, token_expert_indices, + gating_output) + + +def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], + b_qweight: torch.Tensor, b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_qzeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_past_padded: torch.Tensor, + topk_weights: torch.Tensor, moe_block_size: int, + top_k: int, mul_topk_weights: bool, is_ep: bool, + b_q_type: ScalarType, size_m: int, size_n: int, + size_k: int, is_k_full: bool, use_atomic_add: bool, + use_fp32_reduce: bool, + is_zp_float: bool) -> torch.Tensor: + return torch.ops._moe_C.moe_wna16_marlin_gemm( + input, output, b_qweight, b_scales, global_scale, b_qzeros, g_idx, + perm, workspace, sorted_token_ids, expert_ids, num_tokens_past_padded, + topk_weights, moe_block_size, top_k, mul_topk_weights, is_ep, + b_q_type.id, size_m, size_n, size_k, is_k_full, use_atomic_add, + use_fp32_reduce, is_zp_float) + + +if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): + + @register_fake("_moe_C::marlin_gemm_moe") + def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, + sorted_ids: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, b_scales: torch.Tensor, + b_zero_points: torch.Tensor, g_idx: torch.Tensor, + perm: torch.Tensor, workspace: torch.Tensor, + b_q_type: ScalarType, size_m: torch.SymInt, + size_n: torch.SymInt, size_k: torch.SymInt, + is_k_full: bool, num_experts: int, topk: int, + moe_block_size: int, replicate_input: bool, + apply_weights: bool) -> torch.Tensor: + return torch.empty((size_m, topk, size_n), + dtype=a.dtype, + device=a.device) + + @register_fake("_moe_C::moe_wna16_marlin_gemm") + def moe_wna16_marlin_gemm_fake(input: torch.Tensor, + output: Optional[torch.Tensor], + b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_past_padded: torch.Tensor, + topk_weights: torch.Tensor, + moe_block_size: int, top_k: int, + mul_topk_weights: bool, is_ep: bool, + b_q_type: ScalarType, size_m: int, + size_n: int, size_k: int, is_k_full: bool, + use_atomic_add: bool, use_fp32_reduce: bool, + is_zp_float: bool) -> torch.Tensor: + return torch.empty((size_m * top_k, size_n), + dtype=input.dtype, + device=input.device) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype, k_scale, v_scale) + + +def reshape_and_cache_cuda( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + torch.ops._C_cache_ops.reshape_and_cache_cuda(key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype, k_scale, v_scale) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, +) -> None: + torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype, k_scale, + v_scale) + + +def concat_and_cache_mla( + kv_c: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + scale: torch.Tensor, +) -> None: + torch.ops._C_cache_ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, + slot_mapping, kv_cache_dtype, + scale) + + +def copy_blocks(key_caches: list[torch.Tensor], + value_caches: list[torch.Tensor], + block_mapping: torch.Tensor) -> None: + torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def copy_blocks_mla(kv_caches: list[torch.Tensor], + block_mapping: torch.Tensor) -> None: + torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping) + + +def swap_blocks(src: torch.Tensor, dst: torch.Tensor, + block_mapping: torch.Tensor) -> None: + torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8(output: torch.Tensor, + input: torch.Tensor, + scale: float = 1.0, + kv_dtype: str = "fp8") -> None: + torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) + + +def gather_cache(src_cache: torch.Tensor, + dst: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + batch_size: int, + seq_starts: Optional[torch.Tensor] = None) -> None: + torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table, + cu_seq_lens, batch_size, seq_starts) + + +def get_device_attribute(attribute: int, device: int) -> int: + return torch.ops._C_cuda_utils.get_device_attribute(attribute, device) + + +def get_max_shared_memory_per_block_device_attribute(device: int) -> int: + # ruff: noqa: E501 + return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute( + device) + + +# custom ar +def init_custom_ar(ipc_tensors: list[torch.Tensor], rank_data: torch.Tensor, + rank: int, fully_connected: bool) -> int: + return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank, + fully_connected) + + +def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int, + reg_buffer_sz_bytes: int) -> None: + torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, + reg_buffer_sz_bytes) + + +def dispose(fa: int) -> None: + torch.ops._C_custom_ar.dispose(fa) + + +def meta_size() -> int: + return torch.ops._C_custom_ar.meta_size() + + +def register_buffer(fa: int, ipc_tensors: list[int]) -> None: + return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors) + + +def get_graph_buffer_ipc_meta(fa: int) -> tuple[list[int], list[int]]: + return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa) + + +def register_graph_buffers(fa: int, handles: list[list[int]], + offsets: list[list[int]]) -> None: + torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) + + +def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]: + return torch.ops._C_custom_ar.allocate_shared_buffer_and_handle(size) + + +def open_mem_handle(mem_handle: torch.Tensor): + return torch.ops._C_custom_ar.open_mem_handle(mem_handle) + + +def free_shared_buffer(ptr: int) -> None: + torch.ops._C_custom_ar.free_shared_buffer(ptr) + + +def read_cache( + keys: torch.Tensor, + values: torch.Tensor, + key_caches: list[torch.Tensor], + value_caches: list[torch.Tensor], + slot_mapping: torch.Tensor, + kv_cache_dtype: str +) -> None: + torch.ops._C_cache_ops.read_cache(keys, values, key_caches, + value_caches, slot_mapping, + kv_cache_dtype) + +def write_cache_multi_layers( + keys: torch.Tensor, + values: torch.Tensor, + key_caches: list[torch.Tensor], + value_caches: list[torch.Tensor], + slot_mapping: torch.Tensor, + kv_cache_dtype: str +) -> None: + torch.ops._C_cache_ops.write_cache_multi_layers(keys, values, key_caches, + value_caches, slot_mapping, + kv_cache_dtype) + +# quick all reduce +def init_custom_qr(rank: int, + world_size: int, + qr_max_size: Optional[int] = None) -> int: + return torch.ops._C_custom_ar.init_custom_qr(rank, world_size, qr_max_size) + + +def qr_destroy(fa: int) -> None: + torch.ops._C_custom_ar.qr_destroy(fa) + + +def qr_all_reduce(fa: int, + inp: torch.Tensor, + out: torch.Tensor, + quant_level: int, + cast_bf2half: bool = False) -> None: + torch.ops._C_custom_ar.qr_all_reduce(fa, inp, out, quant_level, + cast_bf2half) + + +def qr_get_handle(fa: int) -> torch.Tensor: + return torch.ops._C_custom_ar.qr_get_handle(fa) + + +def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None: + return torch.ops._C_custom_ar.qr_open_handles(fa, handles) + + +def qr_max_size() -> int: + return torch.ops._C_custom_ar.qr_max_size() + + +def get_flash_mla_metadata( + cache_seqlens: torch.Tensor, + num_heads_per_head_k: int, + num_heads_k: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + cache_seqlens: (batch_size), dtype torch.int32. + num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. + num_heads_k: num_heads_k. + + Return: + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: (batch_size + 1), dtype torch.int32. + """ + return torch.ops._C.get_flash_mla_metadata(cache_seqlens, + num_heads_per_head_k, + num_heads_k) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. + cache_seqlens: (batch_size), torch.int32. + head_dim_v: Head_dim of v. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata. + num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim). + causal: bool. Whether to apply causal attention mask. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1]**(-0.5) + out, softmax_lse = torch.ops._C.flash_mla_fwd_kvcache( + q, + k_cache, + None, + head_dim_v, + cache_seqlens, + block_table, + softmax_scale, + causal, + tile_scheduler_metadata, + num_splits, + ) + return out, softmax_lse + + +# def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, +# q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, +# seq_lens: torch.Tensor, page_table: torch.Tensor, +# scale: float) -> torch.Tensor: +# torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache, +# seq_lens, page_table, scale) +# return out + + +def moe_fused_gate( + input_tensor, + bias, + num_expert_group, + topk_group, + topk, + n_share_experts_fusion=0, + routed_scaling_factor=0, +): + # This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion + # it split group of expert into num_expert_group, and use top2 expert weight sum in each group + # as the group weight to select exerpt groups and then select topk experts within the selected groups + # the #experts is decided by the input tensor shape and we currently only support power of 2 #experts + # and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limitted for now. + # for non-supported case, we suggestion to use the biased_grouped_topk func in sglang.srt.layers.moe.topk + # n_share_experts_fusion: if > 0, the last expert will be replaced with a round-robin shared expert + # routed_scaling_factor: if > 0, the last expert will be scaled by this factor + return torch.ops._moe_C.moe_fused_gate( + input_tensor, + bias, + num_expert_group, + topk_group, + topk, + n_share_experts_fusion, + routed_scaling_factor, + ) + +if hasattr(torch.ops._moe_C, "moe_fused_gate"): + + @register_fake("_moe_C::moe_fused_gate") + def moe_fused_gate_fake( + input_tensor: torch.Tensor, + bias: torch.Tensor, + num_expert_group: int, + topk_group: int, + topk: int, + n_share_experts_fusion: int, + routed_scaling_factor: int, + ): + return torch.empty((input_tensor.size(0), topk), + dtype=input_tensor.dtype, + device=input_tensor.device), \ + torch.empty((input_tensor.size(0), topk), + dtype=input_tensor.dtype, + device=input_tensor.device) + + +if hasattr(torch.ops._C, "weight_packed_linear"): + + @register_fake("_C::weight_packed_linear") + def weight_packed_linear_fake(mat1: torch.Tensor, mat2: torch.Tensor, + bias: Optional[torch.Tensor], + is_vnni: bool) -> torch.Tensor: + return torch.empty((mat1.size(0), mat2.size(0)), + dtype=mat1.dtype, + device=mat2.device) + + +if hasattr(torch.ops._C, "fused_experts_cpu"): + + @register_fake("_C::fused_experts_cpu") + def fused_experts_cpu_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool, + use_int8_w8a8: bool, + use_fp8_w8a16: bool, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + block_size: Optional[list[int]], + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + is_vnni: bool, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"): + + @register_fake("_C::int8_scaled_mm_with_quant") + def int8_scaled_mm_with_quant_fake( + mat1: torch.Tensor, + mat2: torch.Tensor, + scales2: torch.Tensor, + bias: Optional[torch.Tensor], + out_dtype: torch.dtype, + is_vnni: bool, + ) -> torch.Tensor: + M = mat1.size(0) + N = mat2.size(0) + return torch.empty((M, N), dtype=out_dtype) + +direct_register_custom_op( + op_name="awq_gemm", + op_func=awq_gemm, + mutates_args=[], + fake_impl=awq_gemm_fake, +) \ No newline at end of file diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py new file mode 100644 index 0000000..7533bf5 --- /dev/null +++ b/vllm/_ipex_ops.py @@ -0,0 +1,350 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +try: + import intel_extension_for_pytorch as ipex +except ImportError as e: + logger.warning("Import error msg: %s", e.msg) + + +class ipex_ops: + + @staticmethod + def _reshape_activation_tensor( + x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + num = x.size(0) + d = x.size(1) // 2 + x = x.reshape(num, 2, d) + x1, x2 = torch.chunk(x, chunks=2, dim=1) + x1 = x1.reshape(num, d) + x2 = x2.reshape(num, d) + return x1, x2 + + @staticmethod + def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + ipex.llm.functional.silu_and_mul(x, out) + + @staticmethod + def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + ipex.llm.functional.gelu_and_mul(x, out) + + @staticmethod + def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + ipex.llm.functional.gelu_and_mul(x, out) + + @staticmethod + def gelu_fast(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x) + + @staticmethod + def gelu_new(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x) + + @staticmethod + def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None: + ipex.llm.functional.gelu_quick(x, out) + + @staticmethod + def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_size: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> None: + assert kv_cache_dtype == "auto" + num_heads = out.size(1) + num_queries_per_tokens = num_heads // num_kv_heads + ipex.llm.modules.PagedAttention.single_query_kv_attention( + out, + query.contiguous(), + key_cache.view_as(value_cache), + value_cache, + num_queries_per_tokens, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) + + @staticmethod + def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_size: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> None: + assert kv_cache_dtype == "auto" + num_heads = out.size(1) + num_queries_per_tokens = num_heads // num_kv_heads + ipex.llm.modules.PagedAttention.single_query_kv_attention( + out, + query.contiguous(), + key_cache.view_as(value_cache), + value_cache, + num_queries_per_tokens, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) + + @staticmethod + def rotary_embedding( + positions: torch.Tensor, # [batch_size, seq_len] + query: torch.Tensor, # [batch_size, seq_len, num_heads*head_size] + key: torch.Tensor, # [batch_size, seq_len, num_kv_heads*head_size] + head_size: int, + cos_sin_cache: torch.Tensor, # [cos_sin_dim, rot_dim] + is_neox: bool, + ) -> None: + rot_dim = cos_sin_cache.size(1) + ipex.llm.functional.rotary_embedding_batched(positions, query, key, + head_size, cos_sin_cache, + is_neox, rot_dim) + + @staticmethod + def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, + key: torch.Tensor, head_size: int, + cos_sin_cache: torch.Tensor, is_neox: bool, + rot_dim: int, + cos_sin_cache_offsets: torch.Tensor) -> None: + ipex.llm.functional.rotary_embedding_batched(positions, query, key, + head_size, cos_sin_cache, + is_neox, rot_dim, + cos_sin_cache_offsets) + + @staticmethod + def rms_norm(input: torch.Tensor, weight: torch.Tensor, + epsilon: float) -> torch.Tensor: + return ipex.llm.functional.rms_norm(input, weight, epsilon) + + @staticmethod + def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, epsilon: float) -> None: + tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None, + epsilon, True) + input.copy_(tmp) + + @staticmethod + def varlen_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + seqlen_q: torch.Tensor, + seqlen_k: torch.Tensor, + alibi_slopes: Optional[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k: int, + pdropout: float, + softmax_scale: float, + zero_tensors: bool, + is_causal: bool, + return_softmax: bool, + gen_: torch.Generator, + window_size_left: float, + window_size_right: float, + logits_soft_cap: float, + ) -> None: + if ipex.__version__.endswith("cpu"): + if logits_soft_cap != 0.0: + raise ValueError("IPEX CPU does not support logits_soft_cap") + assert alibi_slopes is None + assert window_size_left < 0 and window_size_right < 0 + ipex.llm.functional.varlen_attention(query.contiguous(), + key.contiguous(), + value.contiguous(), out, + seqlen_q.int(), + seqlen_k.int(), max_seqlen_q, + max_seqlen_k, pdropout, + softmax_scale, zero_tensors, + is_causal, return_softmax, + gen_) + else: # XPU build + ipex.llm.functional.varlen_attention( + query.contiguous(), key.contiguous(), value.contiguous(), out, + seqlen_q.int(), seqlen_k.int(), alibi_slopes, max_seqlen_q, + max_seqlen_k, pdropout, softmax_scale, zero_tensors, is_causal, + return_softmax, gen_, window_size_left, window_size_right, + logits_soft_cap) + + @staticmethod + def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + ) -> None: + assert kv_cache_dtype == "auto" + ipex.llm.modules.PagedAttention.reshape_and_cache( + key, value, key_cache, value_cache, slot_mapping) + + @staticmethod + def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: Optional[torch.Tensor] = None, + v_scale: Optional[torch.Tensor] = None, + k_scale_float: float = 1.0, + v_scale_float: float = 1.0, + ) -> None: + assert kv_cache_dtype == "auto" + # TODO: support FP8 kv cache. + ipex.llm.modules.PagedAttention.reshape_and_cache_flash( + key, value, key_cache, value_cache, slot_mapping) + + @staticmethod + def flash_attn_varlen_func( + out: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + seqused_k: torch.Tensor, # we don't support this in ipex kernel + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: float, + causal: bool, + block_table: torch.Tensor, + alibi_slopes: Optional[torch.Tensor], + window_size: Optional[list[int]] = None, + softcap: Optional[float] = 0.0, + cu_seqlens_k: Optional[torch.Tensor] = None, + # The following parameters are not used in ipex kernel currently, + # we keep API compatible to CUDA's. + scheduler_metadata=None, + fa_version: int = 2, + q_descale=None, + k_descale=None, + v_descale=None, + num_splits=0, + ): + if cu_seqlens_k is None: + # cu_seqlens_k is not used in ipex kernel. + cu_seqlens_k = torch.cumsum(seqused_k, dim=0) + cu_seqlens_k = torch.cat([ + torch.tensor([0], device=seqused_k.device, dtype=torch.int32), + cu_seqlens_k + ]).to(torch.int32) + + real_window_size: tuple[int, int] + if window_size is None: + real_window_size = (-1, -1) + else: + assert len(window_size) == 2 + real_window_size = (window_size[0], window_size[1]) + return ipex.llm.modules.PagedAttention.flash_attn_varlen_func( + out, + q.contiguous(), + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + causal, + block_table, + alibi_slopes, + softcap=softcap, + window_size_left=real_window_size[0], + window_size_right=real_window_size[1], + k_scale=1.0, + v_scale=1.0, + ) + + @staticmethod + def get_scheduler_metadata( + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads_q, + num_heads_kv, + headdim, + cache_seqlens: torch.Tensor, + qkv_dtype=torch.bfloat16, + headdim_v=None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_size: Optional[int] = None, + max_seqlen_k_new=0, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + has_softcap=False, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication + ) -> None: + logger.warning_once( + "get_scheduler_metadata is not implemented for ipex_ops, " + "returning None.") + return None + + @staticmethod + def copy_blocks(key_caches: list[torch.Tensor], + value_caches: list[torch.Tensor], + block_mapping: torch.Tensor) -> None: + torch.xpu.copy_blocks( # type: ignore + key_caches, + value_caches, + block_mapping, + ) + + @staticmethod + def swap_blocks(src: torch.Tensor, dst: torch.Tensor, + block_mapping: torch.Tensor) -> None: + torch.xpu.swap_blocks(src, dst, block_mapping) # type: ignore diff --git a/vllm/_moe_C.abi3.so b/vllm/_moe_C.abi3.so new file mode 100755 index 0000000..d1934ea Binary files /dev/null and b/vllm/_moe_C.abi3.so differ diff --git a/vllm/adapter_commons/__init__.py b/vllm/adapter_commons/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/adapter_commons/layers.py b/vllm/adapter_commons/layers.py new file mode 100644 index 0000000..9753a08 --- /dev/null +++ b/vllm/adapter_commons/layers.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass + + +@dataclass +class AdapterMapping: + # Per every token in input_ids: + index_mapping: tuple[int, ...] + # Per sampled token: + prompt_mapping: tuple[int, ...] + + def __post_init__(self): + self.index_mapping = tuple(self.index_mapping) + self.prompt_mapping = tuple(self.prompt_mapping) \ No newline at end of file diff --git a/vllm/adapter_commons/models.py b/vllm/adapter_commons/models.py new file mode 100644 index 0000000..7b68588 --- /dev/null +++ b/vllm/adapter_commons/models.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from typing import Any, Callable, Optional, TypeVar + +from torch import nn + +from vllm.logger import init_logger +from vllm.utils import LRUCache + +logger = init_logger(__name__) + + +class AdapterModel(ABC): + + def __init__(self, model_id=None): + self.id = model_id + + @abstractmethod + def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs): + # Common initialization code + # Load weights or embeddings from local checkpoint + raise NotImplementedError("Subclasses must implement this method.") + + +T = TypeVar('T') + + +class AdapterLRUCache(LRUCache[int, T]): + + def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]): + super().__init__(capacity) + self.deactivate_fn = deactivate_fn + + def _on_remove(self, key: int, value: Optional[T]): + logger.debug("Removing adapter int id: %d", key) + self.deactivate_fn(key) + return super()._on_remove(key, value) + + +class AdapterModelManager(ABC): + + def __init__( + self, + model: nn.Module, + ): + """Create a AdapterModelManager and adapter for a given model. + Args: + model: the model to be adapted. + """ + self.model: nn.Module = model + self._registered_adapters: dict[int, Any] = {} + # Dict instead of a Set for compatibility with LRUCache. + self._active_adapters: dict[int, None] = {} + self.adapter_type = 'Adapter' + self._last_mapping = None + + def __len__(self) -> int: + return len(self._registered_adapters) + + @property + @abstractmethod + def adapter_slots(self) -> int: + raise NotImplementedError + + @property + @abstractmethod + def capacity(self) -> int: + raise NotImplementedError + + @abstractmethod + def activate_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def deactivate_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def add_adapter(self, adapter: Any) -> bool: + raise NotImplementedError + + @abstractmethod + def set_adapter_mapping(self, mapping: Any) -> None: + raise NotImplementedError + + @abstractmethod + def remove_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_all_adapters(self) -> None: + raise NotImplementedError + + @abstractmethod + def get_adapter(self, adapter_id: int) -> Optional[Any]: + raise NotImplementedError + + @abstractmethod + def list_adapters(self) -> dict[int, Any]: + raise NotImplementedError + + @abstractmethod + def pin_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError diff --git a/vllm/adapter_commons/request.py b/vllm/adapter_commons/request.py new file mode 100644 index 0000000..8135b54 --- /dev/null +++ b/vllm/adapter_commons/request.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod + + +class AdapterRequest(ABC): + """ + Base class for adapter requests. + """ + + @property + @abstractmethod + def adapter_id(self) -> int: + raise NotImplementedError + + def __post_init__(self) -> None: + if self.adapter_id < 1: + raise ValueError(f"id must be > 0, got {self.adapter_id}") + + def __eq__(self, value: object) -> bool: + return isinstance( + value, self.__class__) and self.adapter_id == value.adapter_id + + def __hash__(self) -> int: + return hash(self.adapter_id) diff --git a/vllm/adapter_commons/utils.py b/vllm/adapter_commons/utils.py new file mode 100644 index 0000000..a1a56b6 --- /dev/null +++ b/vllm/adapter_commons/utils.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Callable, Optional + + +## model functions +def deactivate_adapter(adapter_id: int, active_adapters: dict[int, None], + deactivate_func: Callable) -> bool: + if adapter_id in active_adapters: + deactivate_func(adapter_id) + active_adapters.pop(adapter_id) + return True + return False + + +def add_adapter(adapter: Any, registered_adapters: dict[int, Any], + capacity: int, add_func: Callable) -> bool: + if adapter.id not in registered_adapters: + if len(registered_adapters) >= capacity: + raise RuntimeError('No free adapter slots.') + add_func(adapter) + registered_adapters[adapter.id] = adapter + return True + return False + + +def set_adapter_mapping(mapping: Any, last_mapping: Any, + set_mapping_func: Callable) -> Any: + if last_mapping != mapping: + set_mapping_func(mapping) + return mapping + return last_mapping + + +def remove_adapter(adapter_id: int, registered_adapters: dict[int, Any], + deactivate_func: Callable) -> bool: + deactivate_func(adapter_id) + return bool(registered_adapters.pop(adapter_id, None)) + + +def list_adapters(registered_adapters: dict[int, Any]) -> dict[int, Any]: + return dict(registered_adapters) + + +def get_adapter(adapter_id: int, + registered_adapters: dict[int, Any]) -> Optional[Any]: + return registered_adapters.get(adapter_id) + + +## worker functions +def set_active_adapters_worker(requests: set[Any], mapping: Optional[Any], + apply_adapters_func, + set_adapter_mapping_func) -> None: + apply_adapters_func(requests) + set_adapter_mapping_func(mapping) + + +def add_adapter_worker(adapter_request: Any, list_adapters_func, + load_adapter_func, add_adapter_func, + activate_adapter_func) -> bool: + if adapter_request.adapter_id in list_adapters_func(): + return False + loaded_adapter = load_adapter_func(adapter_request) + loaded = add_adapter_func(loaded_adapter) + activate_adapter_func(loaded_adapter.id) + return loaded + + +def apply_adapters_worker(adapter_requests: set[Any], list_adapters_func, + adapter_slots: int, remove_adapter_func, + add_adapter_func) -> None: + models_that_exist = list_adapters_func() + models_map = { + adapter_request.adapter_id: adapter_request + for adapter_request in adapter_requests if adapter_request + } + if len(models_map) > adapter_slots: + raise RuntimeError( + f"Number of requested models ({len(models_map)}) is greater " + f"than the number of GPU model slots " + f"({adapter_slots}).") + new_models = set(models_map) + models_to_add = new_models - models_that_exist + models_to_remove = models_that_exist - new_models + for adapter_id in models_to_remove: + remove_adapter_func(adapter_id) + for adapter_id in models_to_add: + add_adapter_func(models_map[adapter_id]) + + +def list_adapters_worker(adapter_manager_list_adapters_func) -> set[int]: + return set(adapter_manager_list_adapters_func()) diff --git a/vllm/adapter_commons/worker_manager.py b/vllm/adapter_commons/worker_manager.py new file mode 100644 index 0000000..07e85d1 --- /dev/null +++ b/vllm/adapter_commons/worker_manager.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from typing import Any, Optional + +import torch + + +class AbstractWorkerManager(ABC): + + def __init__(self, device: torch.device): + self.device = device + + @property + @abstractmethod + def is_enabled(self) -> bool: + raise NotImplementedError + + @abstractmethod + def set_active_adapters(self, requests: set[Any], + mapping: Optional[Any]) -> None: + raise NotImplementedError + + @abstractmethod + def add_adapter(self, adapter_request: Any) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_all_adapters(self) -> None: + raise NotImplementedError + + @abstractmethod + def list_adapters(self) -> set[int]: + raise NotImplementedError diff --git a/vllm/assets/__init__.py b/vllm/assets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py new file mode 100644 index 0000000..1c16230 --- /dev/null +++ b/vllm/assets/audio.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from pathlib import Path +from typing import Literal +from urllib.parse import urljoin + +import numpy.typing as npt + +from vllm.utils import PlaceholderModule + +from .base import VLLM_S3_BUCKET_URL, get_vllm_public_assets + +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") # type: ignore[assignment] + +ASSET_DIR = "multimodal_asset" + +AudioAssetName = Literal["winning_call", "mary_had_lamb"] + + +@dataclass(frozen=True) +class AudioAsset: + name: AudioAssetName + + @property + def filename(self) -> str: + return f"{self.name}.ogg" + + @property + def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]: + audio_path = get_vllm_public_assets(filename=self.filename, + s3_prefix=ASSET_DIR) + return librosa.load(audio_path, sr=None) + + def get_local_path(self) -> Path: + return get_vllm_public_assets(filename=self.filename, + s3_prefix=ASSET_DIR) + + @property + def url(self) -> str: + return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg") diff --git a/vllm/assets/base.py b/vllm/assets/base.py new file mode 100644 index 0000000..31cde43 --- /dev/null +++ b/vllm/assets/base.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from functools import lru_cache +from pathlib import Path +from typing import Optional + +import vllm.envs as envs +from vllm.connections import global_http_connection + +VLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com" + + +def get_cache_dir() -> Path: + """Get the path to the cache for storing downloaded assets.""" + path = Path(envs.VLLM_ASSETS_CACHE) + path.mkdir(parents=True, exist_ok=True) + + return path + + +@lru_cache +def get_vllm_public_assets(filename: str, + s3_prefix: Optional[str] = None) -> Path: + """ + Download an asset file from ``s3://vllm-public-assets`` + and return the path to the downloaded file. + """ + asset_directory = get_cache_dir() / "vllm_public_assets" + asset_directory.mkdir(parents=True, exist_ok=True) + + asset_path = asset_directory / filename + if not asset_path.exists(): + if s3_prefix is not None: + filename = s3_prefix + "/" + filename + global_http_connection.download_file( + f"{VLLM_S3_BUCKET_URL}/{filename}", + asset_path, + timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT) + + return asset_path diff --git a/vllm/assets/image.py b/vllm/assets/image.py new file mode 100644 index 0000000..c977242 --- /dev/null +++ b/vllm/assets/image.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import Literal + +import torch +from PIL import Image + +from .base import get_vllm_public_assets + +VLM_IMAGES_DIR = "vision_model_images" + +ImageAssetName = Literal["stop_sign", "cherry_blossom"] + + +@dataclass(frozen=True) +class ImageAsset: + name: ImageAssetName + + @property + def pil_image(self) -> Image.Image: + image_path = get_vllm_public_assets(filename=f"{self.name}.jpg", + s3_prefix=VLM_IMAGES_DIR) + return Image.open(image_path) + + @property + def image_embeds(self) -> torch.Tensor: + """ + Image embeddings, only used for testing purposes with llava 1.5. + """ + image_path = get_vllm_public_assets(filename=f"{self.name}.pt", + s3_prefix=VLM_IMAGES_DIR) + return torch.load(image_path, map_location="cpu", weights_only=True) diff --git a/vllm/assets/video.py b/vllm/assets/video.py new file mode 100644 index 0000000..1641212 --- /dev/null +++ b/vllm/assets/video.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from functools import lru_cache +from typing import Any, ClassVar, Literal, Optional + +import cv2 +import numpy as np +import numpy.typing as npt +from huggingface_hub import hf_hub_download +from PIL import Image + +from vllm.utils import PlaceholderModule + +from .base import get_cache_dir + +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") # type: ignore[assignment] + + +@lru_cache +def download_video_asset(filename: str) -> str: + """ + Download and open an image from huggingface + repo: raushan-testing-hf/videos-test + """ + video_directory = get_cache_dir() / "video-example-data" + video_directory.mkdir(parents=True, exist_ok=True) + + video_path = video_directory / filename + video_path_str = str(video_path) + if not video_path.exists(): + video_path_str = hf_hub_download( + repo_id="raushan-testing-hf/videos-test", + filename=filename, + repo_type="dataset", + cache_dir=video_directory, + ) + return video_path_str + + +def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray: + cap = cv2.VideoCapture(path) + if not cap.isOpened(): + raise ValueError(f"Could not open video file {path}") + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + frames = [] + + num_frames = num_frames if num_frames > 0 else total_frames + frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int) + for idx in range(total_frames): + ok = cap.grab() # next img + if not ok: + break + if idx in frame_indices: # only decompress needed + ret, frame = cap.retrieve() + if ret: + frames.append(frame) + + frames = np.stack(frames) + if len(frames) < num_frames: + raise ValueError(f"Could not read enough frames from video file {path}" + f" (expected {num_frames} frames, got {len(frames)})") + return frames + + +def video_to_pil_images_list(path: str, + num_frames: int = -1) -> list[Image.Image]: + frames = video_to_ndarrays(path, num_frames) + return [ + Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + for frame in frames + ] + + +def video_get_metadata(path: str) -> dict[str, Any]: + cap = cv2.VideoCapture(path) + if not cap.isOpened(): + raise ValueError(f"Could not open video file {path}") + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = cap.get(cv2.CAP_PROP_FPS) + duration = total_frames / fps if fps > 0 else 0 + + metadata = { + "total_num_frames": total_frames, + "fps": fps, + "duration": duration, + "video_backend": "opencv" + } + return metadata + + +VideoAssetName = Literal["baby_reading"] + + +@dataclass(frozen=True) +class VideoAsset: + name: VideoAssetName + num_frames: int = -1 + + _NAME_TO_FILE: ClassVar[dict[VideoAssetName, str]] = { + "baby_reading": "sample_demo_1.mp4", + } + + @property + def filename(self) -> str: + return self._NAME_TO_FILE[self.name] + + @property + def pil_images(self) -> list[Image.Image]: + video_path = download_video_asset(self.filename) + ret = video_to_pil_images_list(video_path, self.num_frames) + return ret + + @property + def np_ndarrays(self) -> npt.NDArray: + video_path = download_video_asset(self.filename) + ret = video_to_ndarrays(video_path, self.num_frames) + return ret + + @property + def metadata(self) -> dict[str, Any]: + video_path = download_video_asset(self.filename) + ret = video_get_metadata(video_path) + return ret + + def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray: + """ + Read audio data from the video asset, used in Qwen2.5-Omni examples. + + See also: examples/offline_inference/qwen2_5_omni/only_thinker.py + """ + video_path = download_video_asset(self.filename) + return librosa.load(video_path, sr=sampling_rate)[0] diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py new file mode 100644 index 0000000..3440405 --- /dev/null +++ b/vllm/attention/__init__.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionState, AttentionType) +from vllm.attention.layer import Attention +from vllm.attention.selector import get_attn_backend + +__all__ = [ + "Attention", + "AttentionBackend", + "AttentionMetadata", + "AttentionType", + "AttentionMetadataBuilder", + "Attention", + "AttentionState", + "get_attn_backend", +] diff --git a/vllm/attention/backends/__init__.py b/vllm/attention/backends/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py new file mode 100644 index 0000000..990ea05 --- /dev/null +++ b/vllm/attention/backends/abstract.py @@ -0,0 +1,325 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from contextlib import contextmanager +from dataclasses import dataclass, fields +from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, + Protocol, Set, Tuple, Type, TypeVar) + +import torch + +from vllm.multimodal import MultiModalPlaceholderMap + +if TYPE_CHECKING: + from vllm.worker.model_runner_base import (ModelRunnerBase, + ModelRunnerInputBase, + ModelRunnerInputBuilderBase) + + +class AttentionType: + """ + Attention type. + Use string to be compatible with `torch.compile`. + """ + # Decoder attention between previous layer Q/K/V + DECODER = "decoder" + # Encoder attention between previous layer Q/K/V for encoder-decoder + ENCODER = "encoder" + # Encoder attention between previous layer Q/K/V + ENCODER_ONLY = "encoder_only" + # Attention between dec. Q and enc. K/V for encoder-decoder + ENCODER_DECODER = "encoder_decoder" + + +class AttentionBackend(ABC): + """Abstract class for attention backends.""" + # For some attention backends, we allocate an output tensor before + # calling the custom op. When piecewise cudagraph is enabled, this + # makes sure the output tensor is allocated inside the cudagraph. + accept_output_buffer: bool = False + + @staticmethod + @abstractmethod + def get_name() -> str: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_impl_cls() -> Type["AttentionImpl"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_state_cls() -> Type["AttentionState"]: + raise NotImplementedError + + @classmethod + def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": + return cls.get_metadata_cls()(*args, **kwargs) + + @staticmethod + @abstractmethod + def get_builder_cls() -> Type["AttentionMetadataBuilder"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + raise NotImplementedError + + @staticmethod + def get_kv_cache_stride_order() -> Tuple[int, ...]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + raise NotImplementedError + + @staticmethod + @abstractmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + raise NotImplementedError + + def advance_step(self, model_input: "ModelRunnerInputBase", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, num_seqs: int, num_queries: int) -> None: + raise NotImplementedError + + +@dataclass +class AttentionMetadata: + """Attention metadata for prefill and decode batched together.""" + # Total number of prefill requests. + num_prefills: int + # Number of prefill tokens. + num_prefill_tokens: int + # Number of decode tokens. Note that it is equivalent to the number of + # decode requests. + num_decode_tokens: int + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. + slot_mapping: torch.Tensor + + # The index maps that relate multi-modal embeddings to the corresponding + # placeholders. + # + # N.B. These aren't really related to attention and don't belong on this + # type -- this is just a temporary solution to make them available to + # `model_executable`. + multi_modal_placeholder_index_maps: Optional[Dict[ + str, MultiModalPlaceholderMap.IndexMap]] + + # Enable/disable KV scales calculation. This is so that we can disable the + # calculation until after prefill and cuda graph capture. + enable_kv_scales_calculation: bool + + @property + @abstractmethod + def prefill_metadata(self) -> Optional["AttentionMetadata"]: + """Return the attention metadata that's required to run prefill + attention.""" + pass + + @property + @abstractmethod + def decode_metadata(self) -> Optional["AttentionMetadata"]: + """Return the attention metadata that's required to run decode + attention.""" + pass + + def asdict_zerocopy(self, + skip_fields: Optional[Set[str]] = None + ) -> Dict[str, Any]: + """Similar to dataclasses.asdict, but avoids deepcopying.""" + if skip_fields is None: + skip_fields = set() + # Note that if we add dataclasses as fields, they will need + # similar handling. + return { + field.name: getattr(self, field.name) + for field in fields(self) if field.name not in skip_fields + } + + +T = TypeVar("T", bound=AttentionMetadata) + + +class AttentionState(ABC, Generic[T]): + """Holds attention backend-specific objects reused during the + lifetime of the model runner.""" + + @abstractmethod + def __init__(self, runner: "ModelRunnerBase"): + ... + + @abstractmethod + @contextmanager + def graph_capture(self, max_batch_size: int): + """Context manager used when capturing CUDA graphs.""" + yield + + @abstractmethod + def graph_clone(self, batch_size: int) -> "AttentionState[T]": + """Clone attention state to save in CUDA graph metadata.""" + ... + + @abstractmethod + def graph_capture_get_metadata_for_batch( + self, + batch_size: int, + is_encoder_decoder_model: bool = False) -> T: + """Get attention metadata for CUDA graph capture of batch_size.""" + ... + + @abstractmethod + def get_graph_input_buffers( + self, + attn_metadata: T, + is_encoder_decoder_model: bool = False) -> Dict[str, Any]: + """Get attention-specific input buffers for CUDA graph capture.""" + ... + + @abstractmethod + def prepare_graph_input_buffers( + self, + input_buffers: Dict[str, Any], + attn_metadata: T, + is_encoder_decoder_model: bool = False) -> None: + """In-place modify input buffers dict for CUDA graph replay.""" + ... + + @abstractmethod + def begin_forward(self, model_input: "ModelRunnerInputBase") -> None: + """Prepare state for forward pass.""" + ... + + +class AttentionMetadataBuilder(ABC, Generic[T]): + """Abstract class for attention metadata builders.""" + + @abstractmethod + def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None: + """Create the builder, remember some configuration and parameters.""" + raise NotImplementedError + + @abstractmethod + def prepare(self) -> None: + """Prepare for one batch.""" + raise NotImplementedError + + @abstractmethod + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int) -> T: + """Build attention metadata with on-device tensors.""" + raise NotImplementedError + + +class AttentionLayer(Protocol): + + _q_scale: torch.Tensor + _k_scale: torch.Tensor + _v_scale: torch.Tensor + _k_scale_float: float + _v_scale_float: float + _prob_scale: torch.Tensor + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + ... + + +class AttentionImpl(ABC, Generic[T]): + + @abstractmethod + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + ) -> None: + raise NotImplementedError + + @abstractmethod + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: T, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, + group_shape: tuple[int, int]): + """ + Does this attention implementation support fused output quantization. + This is used by the AttnFusionPass to only fuse output quantization + onto implementations that support it. + + TODO(luka) merge parameters into QuantDescriptor + :param dtype: quantized dtype + :param static: static or dynamic quantization + :param group_shape: quant group shape. (-1, -1) for per-tensor. + :return: is fusion supported for this type of quantization + """ + return False + + +class MLAAttentionImpl(AttentionImpl[T], Generic[T]): + + @abstractmethod + def forward( + self, + layer: AttentionLayer, + hidden_states_or_cq: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: T, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + +def is_quantized_kv_cache(kv_cache_dtype: str) -> bool: + return kv_cache_dtype != "auto" diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py new file mode 100644 index 0000000..bccc984 --- /dev/null +++ b/vllm/attention/backends/blocksparse_attn.py @@ -0,0 +1,469 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import (CommonAttentionState, + CommonMetadataBuilder) +from vllm.attention.ops.blocksparse_attention.interface import ( + LocalStridedBlockSparseAttn, get_head_sliding_step) +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) + + +@dataclass +class BlocksparseParams: + max_seqlen: int + + # Num q heads per tensor-parallel rank/partition + num_heads: int # per TP partition + # Num kv heads per tensor-parallel rank/partition + num_kv_heads: int + + # block size used for blocksparse attention. + # This is the block_size used in `local_blocks`, `vert_stride`. + block_size: int + + # Number of blocks for local attention, i.e., number of + # local attended tokens / `sparse_block_size` + local_blocks: int + + # Attend to one block per every `vert_stride` blocks. + # Controlling the sparsity + vert_stride: int + """ + If to use the same vertical stride offset for all heads, + i.e., attend to the same block of tokens on all heads. + By default, it is False, i.e., attention on the non-local + blocks depends on the `head_idx`, that is on + blocks satisfying + `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0` + where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`, + `block_idx = position_id // sparse_block_size`. + See `..ops.blocksparse_attention.utils:get_sparse_attn_mask` + for more detail. + """ + homo_head: bool = False + + # If within a group, the kv offsets that each q attends is the same or no. + homo_head_group: bool = False + + # Decided by homo_head and homo_head group + head_sliding_step: int = field(init=False) + + # range of q heads to for a TP rank + active_head_range: Tuple = field(init=False) + + def __post_init__(self): + assert self.block_size > 0 + assert self.local_blocks >= 0 + assert self.vert_stride >= 1 + + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + total_heads = tp_size * self.num_heads + total_kv_heads = tp_size * self.num_kv_heads + + if self.homo_head: + self.head_sliding_step = 0 + elif self.homo_head_group: + head_sliding_step = get_head_sliding_step(total_kv_heads, + self.vert_stride) + # negative indicates sliding along kv heads, i.e., homo q group + self.head_sliding_step = -head_sliding_step + else: + self.head_sliding_step = get_head_sliding_step( + total_heads, self.vert_stride) + + self.active_head_range = ( + tp_rank * self.num_heads, + (tp_rank + 1) * self.num_heads, + ) + + +class BlocksparseFlashAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "BLOCK_SPARSE_FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]: + return BlocksparseFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return BlocksparseFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]: + return BlocksparseFlashAttentionMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class BlocksparseFlashAttentionMetadata(AttentionMetadata): + """A copy of Metadata for FlashAttentionBackend, + to avoid having to install flash_attn. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # Max number of query tokens for among request in the batch. + max_decode_query_len: Optional[int] = None + + _cached_prefill_metadata: Optional[ + "BlocksparseFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional[ + "BlocksparseFlashAttentionMetadata"] = None + + block_tables_list: Optional[List[int]] = None + + @property + def prefill_metadata( + self) -> Optional["BlocksparseFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + block_tables_list=self.block_tables_list + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = BlocksparseFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + block_tables_list=self.block_tables_list + ) + return self._cached_decode_metadata + + +class BlocksparseFlashAttentionMetadataBuilder( + CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]): + + _metadata_cls = BlocksparseFlashAttentionMetadata + + +class BlocksparseFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + assert blocksparse_params is not None + assert alibi_slopes is None, ValueError( + "Alibi not support for blocksparse flash attention.") + assert sliding_window is None, ValueError( + "sliding_window is invalid for blocksparse attention.") + assert logits_soft_cap is None, ValueError( + "logits_soft_cap is invalid for blocksparse attention.") + + if "num_heads" not in blocksparse_params: + blocksparse_params["num_heads"] = num_heads + if "num_kv_heads" not in blocksparse_params: + blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads + self.blocksparse_params = BlocksparseParams(**blocksparse_params) + self.kv_cache_dtype = kv_cache_dtype + + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.alibi_slopes = alibi_slopes + self.num_kv_heads = num_kv_heads + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.local_blocks = self.blocksparse_params.local_blocks + self.vert_stride = self.blocksparse_params.vert_stride + self.sparse_block_size = self.blocksparse_params.block_size + self.head_sliding_step = self.blocksparse_params.head_sliding_step + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + total_num_heads = num_heads * self.tp_size + self.bs_attn = LocalStridedBlockSparseAttn( + total_num_heads, + self.blocksparse_params.max_seqlen, + self.blocksparse_params.local_blocks, + self.blocksparse_params.vert_stride, + self.blocksparse_params.block_size, + homo_head=self.blocksparse_params.homo_head, + active_head_range=self.blocksparse_params.active_head_range, + ) + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "BlocksparseFlashAttentionImpl") + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: BlocksparseFlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for BlocksparseFlashAttentionImpl") + + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache.numel() > 0: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if prefill_meta := attn_metadata.prefill_metadata: + + # Prompt run. + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + + assert kv_cache.numel() == 0 \ + or prefill_meta.block_tables is None \ + or prefill_meta.block_tables.numel() == 0, \ + "Does not support prefix-enabled attention." + + output = self.bs_attn( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + sm_scale=self.scale, + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + output = PagedAttention.forward_decode( + query, + key_cache, + value_cache, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + self.blocksparse_params.max_seqlen, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + layer._k_scale, + layer._v_scale, + tp_rank=self.tp_rank, + blocksparse_local_blocks=self.local_blocks, + blocksparse_vert_stride=self.vert_stride, + blocksparse_block_size=self.sparse_block_size, + blocksparse_head_sliding_step=self.head_sliding_step, + ) + + assert output is not None + # Reshape the output tensor. + return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/configs/QH=16_KVH=1_QKD=576_VD=512_fp16_BW.json b/vllm/attention/backends/configs/QH=16_KVH=1_QKD=576_VD=512_fp16_BW.json new file mode 100644 index 0000000..2f74439 --- /dev/null +++ b/vllm/attention/backends/configs/QH=16_KVH=1_QKD=576_VD=512_fp16_BW.json @@ -0,0 +1,1194 @@ +{ + "1": { + "kernel_kind": "v1_2stages_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 0 + }, + "stage2": { + "BLOCK_N": 16, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 0 + } + }, + "best_us": 87.52100169658661 + }, + "100": { + "kernel_kind": "v1_2stages_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "stage2": { + "BLOCK_N": 64, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 0 + } + }, + "best_us": 109.2820018529892 + }, + "400": { + "kernel_kind": "v1_2stages_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "stage2": { + "BLOCK_N": 32, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 0 + } + }, + "best_us": 179.92249131202698 + }, + "700": { + "kernel_kind": "v1_2stages_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "stage2": { + "BLOCK_N": 32, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 0 + } + }, + "best_us": 266.0830020904541 + }, + "1000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 200.48299431800842 + }, + "1300": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 228.32299768924713 + }, + "1600": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 228.48299145698547 + }, + "1900": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 189.52250480651855 + }, + "2200": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 194.48299705982208 + }, + "2500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 197.84200191497803 + }, + "2800": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 193.44200193881989 + }, + "3100": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 189.28200006484985 + }, + "3400": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 234.2430055141449 + }, + "3700": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 258.80300998687744 + }, + "4000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 274.40300583839417 + }, + "4300": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 208.96300673484802 + }, + "4600": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 213.60298991203308 + }, + "4900": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 205.84198832511902 + }, + "5000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 211.2025022506714 + }, + "5500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 214.8820012807846 + }, + "6000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 269.92300152778625 + }, + "6500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 408.0055058002472 + }, + "7000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 215.84299206733704 + }, + "7500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 234.32299494743347 + }, + "8000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 215.52199125289917 + }, + "8500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 234.40299928188324 + }, + "9000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 226.8030047416687 + }, + "9500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 267.4434781074524 + }, + "10000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 290.08299112319946 + }, + "10500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 273.6029922962189 + }, + "11000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 306.1639964580536 + }, + "11500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 389.28499817848206 + }, + "12000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 270.24298906326294 + }, + "12500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 464.0049934387207 + }, + "13000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 271.3640034198761 + }, + "13500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 279.68400716781616 + }, + "14000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 281.60300850868225 + }, + "14500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 462.48602867126465 + }, + "15000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 383.3639919757843 + }, + "15500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 469.36601400375366 + }, + "16000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 287.8440022468567 + }, + "16500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 295.12351751327515 + }, + "17000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 445.4450011253357 + }, + "17500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 299.5240092277527 + }, + "18000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 305.28348684310913 + }, + "18500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 301.44399404525757 + }, + "19000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 307.4440062046051 + }, + "19500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 309.6030056476593 + }, + "20000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 447.68598675727844 + }, + "20500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 533.6059927940369 + }, + "21000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 327.6839852333069 + }, + "21500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 302.0839989185333 + }, + "22000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 306.4830005168915 + }, + "22500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 324.4040012359619 + }, + "23000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 360.4849874973297 + }, + "23500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 220.16200423240662 + }, + "24000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 220.9630012512207 + }, + "24500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 220.80199420452118 + }, + "25000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 232.6429933309555 + }, + "25500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 232.48299956321716 + }, + "26000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 233.76299440860748 + }, + "26500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 234.40299928188324 + }, + "27000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 244.3230003118515 + }, + "27500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 245.1229989528656 + }, + "28000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 246.40299379825592 + }, + "28500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 246.9629943370819 + }, + "29000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 256.96301460266113 + }, + "29500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 258.0829858779907 + }, + "30000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 258.7229907512665 + }, + "30500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 259.5230042934418 + }, + "31000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 269.12298798561096 + }, + "31500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 269.60399746894836 + }, + "32000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 270.4029977321625 + }, + "32500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 271.84298634529114 + } +} \ No newline at end of file diff --git a/vllm/attention/backends/configs/QH=16_KVH=1_QKD=576_VD=512_fp16_K100AI.json b/vllm/attention/backends/configs/QH=16_KVH=1_QKD=576_VD=512_fp16_K100AI.json new file mode 100644 index 0000000..6111089 --- /dev/null +++ b/vllm/attention/backends/configs/QH=16_KVH=1_QKD=576_VD=512_fp16_K100AI.json @@ -0,0 +1,1190 @@ +{ + "1": { + "kernel_kind": "v1_2stages_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 0 + }, + "stage2": { + "BLOCK_N": 16, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 0 + } + }, + "best_us": 32.48000144958496 + }, + "100": { + "kernel_kind": "v1_2stages_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "stage2": { + "BLOCK_N": 64, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + } + }, + "best_us": 51.04000121355057 + }, + "400": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 81.11999928951263 + }, + "700": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 87.20000088214874 + }, + "1000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 90.55999666452408 + }, + "1300": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 89.43849802017212 + }, + "1600": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 92.6399976015091 + }, + "1900": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 94.71999853849411 + }, + "2200": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 104.80000078678131 + }, + "2500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 105.76000064611435 + }, + "2800": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 108.0000028014183 + }, + "3100": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 107.19999670982361 + }, + "3400": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 108.8000014424324 + }, + "3700": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 110.07999628782272 + }, + "4000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 109.03950035572052 + }, + "4300": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 122.23999947309494 + }, + "4600": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 123.03999811410904 + }, + "4900": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 121.91999703645706 + }, + "5000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 121.44000083208084 + }, + "5500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 123.19999933242798 + }, + "6000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 124.64000284671783 + }, + "6500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 136.48000359535217 + }, + "7000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 137.7590000629425 + }, + "7500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 137.43999600410461 + }, + "8000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 138.2399946451187 + }, + "8500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 149.75999295711517 + }, + "9000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 149.75999295711517 + }, + "9500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 151.19999647140503 + }, + "10000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 152.16000378131866 + }, + "10500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 163.83999586105347 + }, + "11000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 164.32000696659088 + }, + "11500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 164.48000073432922 + }, + "12000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 165.27999937534332 + }, + "12500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 177.279993891716 + }, + "13000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 177.91900038719177 + }, + "13500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 178.24000120162964 + }, + "14000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 179.36000227928162 + }, + "14500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 190.5599981546402 + }, + "15000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 191.3589984178543 + }, + "15500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 191.52000546455383 + }, + "16000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 191.19900465011597 + }, + "16500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 202.55999267101288 + }, + "17000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 202.72000133991241 + }, + "17500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 203.67999374866486 + }, + "18000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 205.11899888515472 + }, + "18500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 216.3199931383133 + }, + "19000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 217.1200066804886 + }, + "19500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 217.75999665260315 + }, + "20000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 218.07999908924103 + }, + "20500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 229.2799949645996 + }, + "21000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 229.76000607013702 + }, + "21500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 228.96000742912292 + }, + "22000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 230.07799685001373 + }, + "22500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 227.84000635147095 + }, + "23000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 242.88000166416168 + }, + "23500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 243.20000410079956 + }, + "24000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 243.6790019273758 + }, + "24500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 241.28000438213348 + }, + "25000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 255.0399899482727 + }, + "25500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 255.51998615264893 + }, + "26000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 255.99899888038635 + }, + "26500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 254.55999374389648 + }, + "27000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 268.15998554229736 + }, + "27500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 268.640011548996 + }, + "28000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 269.76001262664795 + }, + "28500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 269.76001262664795 + }, + "29000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 281.76000714302063 + }, + "29500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 282.24000334739685 + }, + "30000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 283.03951025009155 + }, + "30500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 283.03998708724976 + }, + "31000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 293.92001032829285 + }, + "31500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 295.199990272522 + }, + "32000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 295.6799864768982 + }, + "32500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 295.6790030002594 + } +} \ No newline at end of file diff --git a/vllm/attention/backends/configs/QH=16_KVH=1_QKD=576_VD=512_fp16__default.json b/vllm/attention/backends/configs/QH=16_KVH=1_QKD=576_VD=512_fp16__default.json new file mode 100644 index 0000000..d11c122 --- /dev/null +++ b/vllm/attention/backends/configs/QH=16_KVH=1_QKD=576_VD=512_fp16__default.json @@ -0,0 +1,1194 @@ +{ + "1": { + "kernel_kind": "v1_2stages_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 0 + }, + "stage2": { + "BLOCK_N": 16, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 0 + } + }, + "best_us": 87.52100169658661 + }, + "100": { + "kernel_kind": "v1_2stages_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "stage2": { + "BLOCK_N": 64, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 0 + } + }, + "best_us": 109.2820018529892 + }, + "400": { + "kernel_kind": "v1_2stages_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "stage2": { + "BLOCK_N": 32, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 0 + } + }, + "best_us": 179.92249131202698 + }, + "700": { + "kernel_kind": "v1_2stages_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "stage2": { + "BLOCK_N": 32, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 0 + } + }, + "best_us": 266.0830020904541 + }, + "1000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 200.48299431800842 + }, + "1300": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 228.32299768924713 + }, + "1600": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 228.48299145698547 + }, + "1900": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 189.52250480651855 + }, + "2200": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 194.48299705982208 + }, + "2500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 197.84200191497803 + }, + "2800": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 193.44200193881989 + }, + "3100": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 189.28200006484985 + }, + "3400": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 234.2430055141449 + }, + "3700": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 258.80300998687744 + }, + "4000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 274.40300583839417 + }, + "4300": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 208.96300673484802 + }, + "4600": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 213.60298991203308 + }, + "4900": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 205.84198832511902 + }, + "5000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 211.2025022506714 + }, + "5500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 214.8820012807846 + }, + "6000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 269.92300152778625 + }, + "6500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 408.0055058002472 + }, + "7000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 215.84299206733704 + }, + "7500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 234.32299494743347 + }, + "8000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 215.52199125289917 + }, + "8500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 234.40299928188324 + }, + "9000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 226.8030047416687 + }, + "9500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 267.4434781074524 + }, + "10000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 290.08299112319946 + }, + "10500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 273.6029922962189 + }, + "11000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 306.1639964580536 + }, + "11500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 389.28499817848206 + }, + "12000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 270.24298906326294 + }, + "12500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 464.0049934387207 + }, + "13000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 271.3640034198761 + }, + "13500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 279.68400716781616 + }, + "14000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 281.60300850868225 + }, + "14500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 462.48602867126465 + }, + "15000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 383.3639919757843 + }, + "15500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 469.36601400375366 + }, + "16000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 287.8440022468567 + }, + "16500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 295.12351751327515 + }, + "17000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 445.4450011253357 + }, + "17500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 299.5240092277527 + }, + "18000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 305.28348684310913 + }, + "18500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 301.44399404525757 + }, + "19000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 307.4440062046051 + }, + "19500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 309.6030056476593 + }, + "20000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 447.68598675727844 + }, + "20500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 533.6059927940369 + }, + "21000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 327.6839852333069 + }, + "21500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 302.0839989185333 + }, + "22000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 306.4830005168915 + }, + "22500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 324.4040012359619 + }, + "23000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 360.4849874973297 + }, + "23500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 220.16200423240662 + }, + "24000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 220.9630012512207 + }, + "24500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 220.80199420452118 + }, + "25000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 232.6429933309555 + }, + "25500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 232.48299956321716 + }, + "26000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 233.76299440860748 + }, + "26500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 234.40299928188324 + }, + "27000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 244.3230003118515 + }, + "27500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 245.1229989528656 + }, + "28000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 246.40299379825592 + }, + "28500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 246.9629943370819 + }, + "29000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 256.96301460266113 + }, + "29500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 258.0829858779907 + }, + "30000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 258.7229907512665 + }, + "30500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 259.5230042934418 + }, + "31000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 269.12298798561096 + }, + "31500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 269.60399746894836 + }, + "32000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 270.4029977321625 + }, + "32500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 271.84298634529114 + } + } \ No newline at end of file diff --git a/vllm/attention/backends/configs/QH=4_KVH=1_QKD=576_VD=512_fp16_BW.json b/vllm/attention/backends/configs/QH=4_KVH=1_QKD=576_VD=512_fp16_BW.json new file mode 100644 index 0000000..ade83c0 --- /dev/null +++ b/vllm/attention/backends/configs/QH=4_KVH=1_QKD=576_VD=512_fp16_BW.json @@ -0,0 +1,1186 @@ +{ + "1": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 60.800500214099884 + }, + "100": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 76.16099715232849 + }, + "400": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 76.64000242948532 + }, + "700": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 81.76100254058838 + }, + "1000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 84.00099724531174 + }, + "1300": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 82.88100361824036 + }, + "1600": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 85.12099832296371 + }, + "1900": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 86.72100305557251 + }, + "2200": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 96.00099921226501 + }, + "2500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 96.80099785327911 + }, + "2800": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 98.08100014925003 + }, + "3100": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 97.12100028991699 + }, + "3400": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 98.40100258588791 + }, + "3700": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 99.28150475025177 + }, + "4000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 98.72099757194519 + }, + "4300": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 109.2820018529892 + }, + "4600": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 109.44200307130814 + }, + "4900": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 109.76099967956543 + }, + "5000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 108.96199941635132 + }, + "5500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 110.48150062561035 + }, + "6000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 111.28149926662445 + }, + "6500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 121.12099677324295 + }, + "7000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 122.08200246095657 + }, + "7500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 122.88100272417068 + }, + "8000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 122.5619986653328 + }, + "8500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 132.64200091362 + }, + "9000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 133.12149047851562 + }, + "9500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 133.4419995546341 + }, + "10000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 134.08200442790985 + }, + "10500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 144.16199922561646 + }, + "11000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 144.80100572109222 + }, + "11500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 145.6020027399063 + }, + "12000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 145.76199650764465 + }, + "12500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 156.32200241088867 + }, + "13000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 157.44100511074066 + }, + "13500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 157.76200592517853 + }, + "14000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 158.0819934606552 + }, + "14500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 168.4820055961609 + }, + "15000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 169.44199800491333 + }, + "15500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 169.44199800491333 + }, + "16000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 169.12199556827545 + }, + "16500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 179.52199280261993 + }, + "17000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 180.48200011253357 + }, + "17500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 180.6419938802719 + }, + "18000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 181.12200498580933 + }, + "18500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 192.16251373291016 + }, + "19000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 192.80199706554413 + }, + "19500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 193.121999502182 + }, + "20000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 193.3625042438507 + }, + "20500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 204.48200404644012 + }, + "21000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 204.803004860878 + }, + "21500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 204.0019929409027 + }, + "22000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 204.96299862861633 + }, + "22500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 204.803004860878 + }, + "23000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 215.84199368953705 + }, + "23500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 216.80200099945068 + }, + "24000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 216.96199476718903 + }, + "24500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 217.2829955816269 + }, + "25000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 228.48300635814667 + }, + "25500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 228.32299768924713 + }, + "26000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 228.80299389362335 + }, + "26500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 229.28300499916077 + }, + "27000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 239.8429960012436 + }, + "27500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 240.16299843788147 + }, + "28000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 241.60300195217133 + }, + "28500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 242.08299815654755 + }, + "29000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 252.8029978275299 + }, + "29500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 253.44300270080566 + }, + "30000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 253.28299403190613 + }, + "30500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 253.92299890518188 + }, + "31000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 264.164000749588 + }, + "31500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 265.12300968170166 + }, + "32000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 265.44299721717834 + }, + "32500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 266.00348949432373 + } +} \ No newline at end of file diff --git a/vllm/attention/backends/configs/QH=4_KVH=1_QKD=576_VD=512_fp16_K100AI.json b/vllm/attention/backends/configs/QH=4_KVH=1_QKD=576_VD=512_fp16_K100AI.json new file mode 100644 index 0000000..9400548 --- /dev/null +++ b/vllm/attention/backends/configs/QH=4_KVH=1_QKD=576_VD=512_fp16_K100AI.json @@ -0,0 +1,1190 @@ +{ + "1": { + "kernel_kind": "v1_2stages_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 0 + }, + "stage2": { + "BLOCK_N": 16, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + } + }, + "best_us": 30.559999868273735 + }, + "100": { + "kernel_kind": "v1_2stages_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "stage2": { + "BLOCK_N": 64, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + } + }, + "best_us": 48.48000034689903 + }, + "400": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 77.27999985218048 + }, + "700": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 83.03999900817871 + }, + "1000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 86.40000224113464 + }, + "1300": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 85.60000360012054 + }, + "1600": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 88.16000074148178 + }, + "1900": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 90.08000046014786 + }, + "2200": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 100.3199964761734 + }, + "2500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 101.59949958324432 + }, + "2800": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 102.79950499534607 + }, + "3100": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 102.88000106811523 + }, + "3400": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 104.47999835014343 + }, + "3700": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 105.27999699115753 + }, + "4000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 104.47999835014343 + }, + "4300": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 117.11999773979187 + }, + "4600": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 117.91999638080597 + }, + "4900": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 117.76000261306763 + }, + "5000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 117.27949976921082 + }, + "5500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 119.03999745845795 + }, + "6000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 120.4800009727478 + }, + "6500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 131.04000687599182 + }, + "7000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 132.1599930524826 + }, + "7500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 133.27999413013458 + }, + "8000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 133.59999656677246 + }, + "8500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 144.31999623775482 + }, + "9000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 145.28000354766846 + }, + "9500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 146.08000218868256 + }, + "10000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 145.91999351978302 + }, + "10500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 157.75899589061737 + }, + "11000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 158.39999914169312 + }, + "11500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 159.04000401496887 + }, + "12000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 159.9999964237213 + }, + "12500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 171.6800034046173 + }, + "13000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 172.31999337673187 + }, + "13500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 173.11950027942657 + }, + "14000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 173.2800006866455 + }, + "14500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 184.15899574756622 + }, + "15000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 185.44000387191772 + }, + "15500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 185.59999763965607 + }, + "16000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 185.7600063085556 + }, + "16500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 196.6399997472763 + }, + "17000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 197.76000082492828 + }, + "17500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 197.60000705718994 + }, + "18000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 198.95949959754944 + }, + "18500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 210.7200026512146 + }, + "19000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 210.7200026512146 + }, + "19500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 211.5200012922287 + }, + "20000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 211.19999885559082 + }, + "20500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 223.83999824523926 + }, + "21000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 224.16000068187714 + }, + "21500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 222.56000339984894 + }, + "22000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 223.51999580860138 + }, + "22500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 222.71999716758728 + }, + "23000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 236.64000630378723 + }, + "23500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 236.32000386714935 + }, + "24000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 237.2799962759018 + }, + "24500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 234.40000414848328 + }, + "25000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 249.27900731563568 + }, + "25500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 248.48000705242157 + }, + "26000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 250.2399981021881 + }, + "26500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 248.31999838352203 + }, + "27000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 261.27898693084717 + }, + "27500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 261.4400088787079 + }, + "28000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 262.56000995635986 + }, + "28500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 263.2000148296356 + }, + "29000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 275.04000067710876 + }, + "29500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 276.15898847579956 + }, + "30000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 275.519996881485 + }, + "30500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 276.15898847579956 + }, + "31000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 288.1599962711334 + }, + "31500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 287.9999876022339 + }, + "32000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 288.1599962711334 + }, + "32500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 289.11998867988586 + } +} \ No newline at end of file diff --git a/vllm/attention/backends/configs/QH=4_KVH=1_QKD=576_VD=512_fp16_default.json b/vllm/attention/backends/configs/QH=4_KVH=1_QKD=576_VD=512_fp16_default.json new file mode 100644 index 0000000..9ed31b1 --- /dev/null +++ b/vllm/attention/backends/configs/QH=4_KVH=1_QKD=576_VD=512_fp16_default.json @@ -0,0 +1,1186 @@ +{ + "1": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 60.800500214099884 + }, + "100": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 76.16099715232849 + }, + "400": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 76.64000242948532 + }, + "700": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 81.76100254058838 + }, + "1000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 84.00099724531174 + }, + "1300": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 82.88100361824036 + }, + "1600": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 85.12099832296371 + }, + "1900": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 86.72100305557251 + }, + "2200": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 96.00099921226501 + }, + "2500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 96.80099785327911 + }, + "2800": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 98.08100014925003 + }, + "3100": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 97.12100028991699 + }, + "3400": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 98.40100258588791 + }, + "3700": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 99.28150475025177 + }, + "4000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 98.72099757194519 + }, + "4300": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 109.2820018529892 + }, + "4600": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 109.44200307130814 + }, + "4900": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 109.76099967956543 + }, + "5000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 108.96199941635132 + }, + "5500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 110.48150062561035 + }, + "6000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 111.28149926662445 + }, + "6500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 121.12099677324295 + }, + "7000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 122.08200246095657 + }, + "7500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 122.88100272417068 + }, + "8000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 122.5619986653328 + }, + "8500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 132.64200091362 + }, + "9000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 133.12149047851562 + }, + "9500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 133.4419995546341 + }, + "10000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 134.08200442790985 + }, + "10500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 144.16199922561646 + }, + "11000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 144.80100572109222 + }, + "11500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 145.6020027399063 + }, + "12000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 145.76199650764465 + }, + "12500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 156.32200241088867 + }, + "13000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 157.44100511074066 + }, + "13500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 157.76200592517853 + }, + "14000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 158.0819934606552 + }, + "14500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 168.4820055961609 + }, + "15000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 169.44199800491333 + }, + "15500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 169.44199800491333 + }, + "16000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 169.12199556827545 + }, + "16500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 179.52199280261993 + }, + "17000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 180.48200011253357 + }, + "17500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 180.6419938802719 + }, + "18000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 181.12200498580933 + }, + "18500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 192.16251373291016 + }, + "19000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 192.80199706554413 + }, + "19500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 193.121999502182 + }, + "20000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 193.3625042438507 + }, + "20500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 204.48200404644012 + }, + "21000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 204.803004860878 + }, + "21500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 204.0019929409027 + }, + "22000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 204.96299862861633 + }, + "22500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 204.803004860878 + }, + "23000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 215.84199368953705 + }, + "23500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 216.80200099945068 + }, + "24000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 216.96199476718903 + }, + "24500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 217.2829955816269 + }, + "25000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 228.48300635814667 + }, + "25500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 228.32299768924713 + }, + "26000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 228.80299389362335 + }, + "26500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 229.28300499916077 + }, + "27000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 239.8429960012436 + }, + "27500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 240.16299843788147 + }, + "28000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 241.60300195217133 + }, + "28500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 242.08299815654755 + }, + "29000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 252.8029978275299 + }, + "29500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 253.44300270080566 + }, + "30000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 253.28299403190613 + }, + "30500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 253.92299890518188 + }, + "31000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 264.164000749588 + }, + "31500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 265.12300968170166 + }, + "32000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 265.44299721717834 + }, + "32500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 266.00348949432373 + } + } \ No newline at end of file diff --git a/vllm/attention/backends/configs/QH=8_KVH=1_QKD=576_VD=512_fp16_BW.json b/vllm/attention/backends/configs/QH=8_KVH=1_QKD=576_VD=512_fp16_BW.json new file mode 100644 index 0000000..5728930 --- /dev/null +++ b/vllm/attention/backends/configs/QH=8_KVH=1_QKD=576_VD=512_fp16_BW.json @@ -0,0 +1,1186 @@ +{ + "1": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 42.40100085735321 + }, + "100": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 76.48099958896637 + }, + "400": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 76.80100202560425 + }, + "700": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 81.76100254058838 + }, + "1000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 85.76100319623947 + }, + "1300": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 83.68100225925446 + }, + "1600": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 85.44149994850159 + }, + "1900": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 86.88099682331085 + }, + "2200": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 96.16100043058395 + }, + "2500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 96.96099907159805 + }, + "2800": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 99.36200082302094 + }, + "3100": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 98.56099635362625 + }, + "3400": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 99.20100122690201 + }, + "3700": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 100.64099729061127 + }, + "4000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 257.9230070114136 + }, + "4300": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 109.44150388240814 + }, + "4600": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 110.88100075721741 + }, + "4900": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 110.72099953889847 + }, + "5000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 109.44099724292755 + }, + "5500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 111.84199899435043 + }, + "6000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 112.32200264930725 + }, + "6500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 359.36400294303894 + }, + "7000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 360.164999961853 + }, + "7500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 362.8849983215332 + }, + "8000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 153.92200648784637 + }, + "8500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 133.92199575901031 + }, + "9000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 201.68299973011017 + }, + "9500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 134.88100469112396 + }, + "10000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 135.8419954776764 + }, + "10500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 178.24199795722961 + }, + "11000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 171.76198959350586 + }, + "11500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 146.7210054397583 + }, + "12000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 601.127028465271 + }, + "12500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 681.9279789924622 + }, + "13000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 284.00298953056335 + }, + "13500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 285.1240038871765 + }, + "14000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 282.803475856781 + }, + "14500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 280.80400824546814 + }, + "15000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 297.9240119457245 + }, + "15500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 280.80400824546814 + }, + "16000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 621.288001537323 + }, + "16500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 289.28399085998535 + }, + "17000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 290.4840111732483 + }, + "17500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 295.2040135860443 + }, + "18000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 296.644002199173 + }, + "18500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 290.88300466537476 + }, + "19000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 273.4430134296417 + }, + "19500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 316.88401103019714 + }, + "20000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 278.8830101490021 + }, + "20500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 558.8070154190063 + }, + "21000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 305.52399158477783 + }, + "21500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 298.08300733566284 + }, + "22000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 303.84400486946106 + }, + "22500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 309.2834949493408 + }, + "23000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 388.00498843193054 + }, + "23500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 361.76449060440063 + }, + "24000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 636.1680030822754 + }, + "24500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 357.76448249816895 + }, + "25000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 375.6850063800812 + }, + "25500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 362.964004278183 + }, + "26000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 362.0845079421997 + }, + "26500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 361.92500591278076 + }, + "27000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 367.52501130104065 + }, + "27500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 690.887987613678 + }, + "28000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 375.2039968967438 + }, + "28500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 376.32399797439575 + }, + "29000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 389.4439935684204 + }, + "29500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 387.20399141311646 + }, + "30000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 414.48551416397095 + }, + "30500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 383.20451974868774 + }, + "31000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 446.56500220298767 + }, + "31500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 627.2079944610596 + }, + "32000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 454.00500297546387 + }, + "32500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 428.00599336624146 + } +} \ No newline at end of file diff --git a/vllm/attention/backends/configs/QH=8_KVH=1_QKD=576_VD=512_fp16_K100AI.json b/vllm/attention/backends/configs/QH=8_KVH=1_QKD=576_VD=512_fp16_K100AI.json new file mode 100644 index 0000000..c7dfdd4 --- /dev/null +++ b/vllm/attention/backends/configs/QH=8_KVH=1_QKD=576_VD=512_fp16_K100AI.json @@ -0,0 +1,1190 @@ +{ + "1": { + "kernel_kind": "v1_2stages_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 0 + }, + "stage2": { + "BLOCK_N": 16, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + } + }, + "best_us": 31.360000371932983 + }, + "100": { + "kernel_kind": "v1_2stages_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "stage2": { + "BLOCK_N": 64, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + } + }, + "best_us": 48.79999905824661 + }, + "400": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 78.5600021481514 + }, + "700": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 84.63999629020691 + }, + "1000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 87.99999952316284 + }, + "1300": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 86.87999844551086 + }, + "1600": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 89.75999802350998 + }, + "1900": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 91.839998960495 + }, + "2200": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 102.08000242710114 + }, + "2500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 103.35999727249146 + }, + "2800": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 104.80000078678131 + }, + "3100": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 104.3199971318245 + }, + "3400": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 105.43999820947647 + }, + "3700": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 106.39999806880951 + }, + "4000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 106.55999928712845 + }, + "4300": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 119.03949826955795 + }, + "4600": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 119.84000355005264 + }, + "4900": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 119.35999989509583 + }, + "5000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 118.56000125408173 + }, + "5500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 120.4800009727478 + }, + "6000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 121.11999839544296 + }, + "6500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 133.27999413013458 + }, + "7000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 134.39999520778656 + }, + "7500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 134.88000631332397 + }, + "8000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 135.04000008106232 + }, + "8500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 146.55999839305878 + }, + "9000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 146.55999839305878 + }, + "9500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 148.00000190734863 + }, + "10000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 148.3200043439865 + }, + "10500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 159.9999964237213 + }, + "11000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 159.9999964237213 + }, + "11500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 160.64000129699707 + }, + "12000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 162.08000481128693 + }, + "12500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 173.6000031232834 + }, + "13000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 175.99999904632568 + }, + "13500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 175.20000040531158 + }, + "14000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 175.35999417304993 + }, + "14500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 186.71999871730804 + }, + "15000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 186.71999871730804 + }, + "15500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 187.83999979496002 + }, + "16000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 186.88000738620758 + }, + "16500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 199.0399956703186 + }, + "17000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 199.8399943113327 + }, + "17500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 200.00000298023224 + }, + "18000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 200.95999538898468 + }, + "18500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 211.84000372886658 + }, + "19000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 213.76000344753265 + }, + "19500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 213.919997215271 + }, + "20000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 213.918998837471 + }, + "20500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 225.11999309062958 + }, + "21000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 225.9189933538437 + }, + "21500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 225.43999552726746 + }, + "22000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 226.23999416828156 + }, + "22500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 224.31999444961548 + }, + "23000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 238.87999355793 + }, + "23500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 238.39999735355377 + }, + "24000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 239.51999843120575 + }, + "24500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 236.80000007152557 + }, + "25000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 252.00000405311584 + }, + "25500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 250.71999430656433 + }, + "26000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 251.99949741363525 + }, + "26500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 249.59999322891235 + }, + "27000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 263.5200023651123 + }, + "27500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 264.8000121116638 + }, + "28000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 265.1199996471405 + }, + "28500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 265.9189999103546 + }, + "29000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 277.75999903678894 + }, + "29500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 277.75898575782776 + }, + "30000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 279.04000878334045 + }, + "30500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 278.56001257896423 + }, + "31000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 289.92000222206116 + }, + "31500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 290.23998975753784 + }, + "32000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 289.6000146865845 + }, + "32500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 291.03949666023254 + } +} \ No newline at end of file diff --git a/vllm/attention/backends/configs/QH=8_KVH=1_QKD=576_VD=512_fp16_default.json b/vllm/attention/backends/configs/QH=8_KVH=1_QKD=576_VD=512_fp16_default.json new file mode 100644 index 0000000..9ff7a0e --- /dev/null +++ b/vllm/attention/backends/configs/QH=8_KVH=1_QKD=576_VD=512_fp16_default.json @@ -0,0 +1,1186 @@ +{ + "1": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 42.40100085735321 + }, + "100": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 76.48099958896637 + }, + "400": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 76.80100202560425 + }, + "700": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 81.76100254058838 + }, + "1000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 85.76100319623947 + }, + "1300": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 83.68100225925446 + }, + "1600": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 85.44149994850159 + }, + "1900": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 86.88099682331085 + }, + "2200": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 96.16100043058395 + }, + "2500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 96.96099907159805 + }, + "2800": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 99.36200082302094 + }, + "3100": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 98.56099635362625 + }, + "3400": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 99.20100122690201 + }, + "3700": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 100.64099729061127 + }, + "4000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 257.9230070114136 + }, + "4300": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 109.44150388240814 + }, + "4600": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 110.88100075721741 + }, + "4900": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 110.72099953889847 + }, + "5000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 109.44099724292755 + }, + "5500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 111.84199899435043 + }, + "6000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 112.32200264930725 + }, + "6500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 359.36400294303894 + }, + "7000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 360.164999961853 + }, + "7500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 362.8849983215332 + }, + "8000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 153.92200648784637 + }, + "8500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 133.92199575901031 + }, + "9000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 201.68299973011017 + }, + "9500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 134.88100469112396 + }, + "10000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 135.8419954776764 + }, + "10500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 178.24199795722961 + }, + "11000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 171.76198959350586 + }, + "11500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 146.7210054397583 + }, + "12000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 601.127028465271 + }, + "12500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 681.9279789924622 + }, + "13000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 284.00298953056335 + }, + "13500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 285.1240038871765 + }, + "14000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 282.803475856781 + }, + "14500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 280.80400824546814 + }, + "15000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 297.9240119457245 + }, + "15500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 280.80400824546814 + }, + "16000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 621.288001537323 + }, + "16500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 289.28399085998535 + }, + "17000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 290.4840111732483 + }, + "17500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 295.2040135860443 + }, + "18000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 296.644002199173 + }, + "18500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 290.88300466537476 + }, + "19000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 273.4430134296417 + }, + "19500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 316.88401103019714 + }, + "20000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 278.8830101490021 + }, + "20500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 558.8070154190063 + }, + "21000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 305.52399158477783 + }, + "21500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 298.08300733566284 + }, + "22000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 303.84400486946106 + }, + "22500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 309.2834949493408 + }, + "23000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 388.00498843193054 + }, + "23500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 361.76449060440063 + }, + "24000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 636.1680030822754 + }, + "24500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 357.76448249816895 + }, + "25000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 375.6850063800812 + }, + "25500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 362.964004278183 + }, + "26000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 362.0845079421997 + }, + "26500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 361.92500591278076 + }, + "27000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 367.52501130104065 + }, + "27500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 32, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 690.887987613678 + }, + "28000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 375.2039968967438 + }, + "28500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 376.32399797439575 + }, + "29000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 389.4439935684204 + }, + "29500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 387.20399141311646 + }, + "30000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 4 + } + }, + "best_us": 414.48551416397095 + }, + "30500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 4 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 383.20451974868774 + }, + "31000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 446.56500220298767 + }, + "31500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 627.2079944610596 + }, + "32000": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 2 + } + }, + "best_us": 454.00500297546387 + }, + "32500": { + "kernel_kind": "v2_tc", + "best_config": { + "stage1": { + "BLOCK_N": 16, + "BLOCK_DIM": 64, + "num_stages": 1, + "num_warps": 2 + }, + "stage2": { + "num_stages": 1, + "num_warps": 8 + } + }, + "best_us": 428.00599336624146 + } + } \ No newline at end of file diff --git a/vllm/attention/backends/cpu_mla.py b/vllm/attention/backends/cpu_mla.py new file mode 100644 index 0000000..793cb87 --- /dev/null +++ b/vllm/attention/backends/cpu_mla.py @@ -0,0 +1,307 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +import vllm._custom_ops as ops +from vllm._ipex_ops import ipex_ops +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadataBuilder, + AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonState +from vllm.attention.backends.torch_sdpa import TorchSDPAMetadata +from vllm.utils import make_tensor_with_pad +from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder + + +class CPUMLABackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "CPU_MLA" + + @staticmethod + def get_metadata_cls() -> Type["CPUMLAMetadata"]: + return CPUMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]: + return CPUMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["MLACommonState"]: + return MLACommonState + + @staticmethod + def get_impl_cls() -> Type["CPUMLAImpl"]: + return CPUMLAImpl + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + ops.copy_blocks_mla(kv_caches, src_to_dists) + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [576] + + +@dataclass +class CPUMLAMetadata(TorchSDPAMetadata): + # New for MLA + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor = None + + # required by MLACommonImpl + is_profile_run: bool = False + + +class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]): + + def __init__(self, input_builder: ModelInputForCPUBuilder) -> None: + self.chunked_prefill = input_builder.chunked_prefill + self.input_builder = input_builder + assert not self.chunked_prefill, \ + "chunked prefill is currently not supported" + + def prepare(self): + self.input_data = self.input_builder.input_data + + def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size): + input_data = self.input_data + prefill_seq_lens = seq_lens[0:input_data.num_prefills] + prefill_query_lens = query_lens[0:input_data.num_prefills] + slot_mapping = torch.tensor(input_data.slot_mapping, + dtype=torch.long, + device="cpu") + + # metadata for prefill + if input_data.num_prefills > 0: + query_lens_tensor = torch.tensor(prefill_query_lens, + dtype=torch.int32, + device="cpu") + kv_lens_tensor = torch.tensor(prefill_seq_lens, + dtype=torch.int32, + device="cpu") + query_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + kv_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + torch.cumsum(query_lens_tensor, + dim=0, + dtype=torch.int32, + out=query_start_loc[1:]) + torch.cumsum(kv_lens_tensor, + dim=0, + dtype=torch.int32, + out=kv_start_loc[1:]) + max_query_len = max(prefill_query_lens) + max_kv_len = max(prefill_seq_lens) + + # for chunked-prefill + if self.chunked_prefill: + prefill_block_tables = make_tensor_with_pad( + self.input_data.prefill_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + else: + prefill_block_tables = None + + else: + query_start_loc = None + kv_start_loc = None + max_query_len = None + max_kv_len = None + prefill_block_tables = None + + # metadata for decode + if input_data.num_decode_tokens != 0: + seq_lens_tensor = torch.tensor( + input_data.seq_lens[input_data.num_prefills:], + dtype=torch.int32, + device="cpu", + ) + block_tables = make_tensor_with_pad( + self.input_data.decode_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + else: + block_tables = torch.tensor([]) + seq_lens_tensor = torch.tensor( + input_data.seq_lens[:input_data.num_prefills], + dtype=torch.int32, + device="cpu", + ) + + # For multi-modal models + placeholder_index_maps = None + if len(input_data.multi_modal_inputs_list) != 0: + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + input_data.multi_modal_placeholder_maps.items() + } + + return CPUMLAMetadata( + chunked_prefill=self.chunked_prefill, + seq_lens=prefill_seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_kv_len=max_kv_len, + prefill_query_start_loc=query_start_loc, + kv_start_loc=kv_start_loc, + max_decode_seq_len=input_data.max_decode_seq_len, + num_prefills=input_data.num_prefills, + num_prefill_tokens=input_data.num_prefill_tokens, + num_decode_tokens=input_data.num_decode_tokens, + block_tables=block_tables, + prefill_block_tables=prefill_block_tables, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, + input_positions=torch.tensor([self.input_data.input_positions])) + + +class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "CPUMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "CPUMLAImpl") + + # states is implemented. + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "CPUMLAImpl with FP8 KV cache not yet supported") + + def _forward_prefill( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: CPUMLAMetadata, # type: ignore[override] + ) -> torch.Tensor: + + prefill_metadata = attn_metadata.prefill_metadata + assert prefill_metadata is not None + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim + v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) + + output = torch.empty_like(q) + ipex_ops.varlen_attention( + query=q, + key=k, + value=v_padded, + out=output, + seqlen_q=prefill_metadata.prefill_query_start_loc, + seqlen_k=prefill_metadata.prefill_query_start_loc, + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata.max_query_len, + pdropout=0.0, + softmax_scale=self.scale, + zero_tensors=False, + is_causal=True, + return_softmax=False, + gen_=None, + logits_soft_cap=0.0, + window_size_left=-1, + window_size_right=-1, + alibi_slopes=None, + ) + + # remove padding + output = output.view(-1, self.num_heads, + q.shape[-1])[..., :v.shape[-1]] + return output.reshape(-1, self.num_heads * v.shape[-1]) + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: CPUMLAMetadata, # type: ignore[override] + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + + q = torch.cat([q_nope, q_pe], dim=-1) + o = q.new_empty(q.shape[0], self.num_heads, self.kv_lora_rank) + + # Run MQA + ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale, + decode_meta.block_tables, + decode_meta.seq_lens_tensor) + return self._v_up_proj(o) diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py new file mode 100644 index 0000000..a99d7d5 --- /dev/null +++ b/vllm/attention/backends/dual_chunk_flash_attn.py @@ -0,0 +1,1530 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer with Dual chunk flash attention and sparse attention. +""" +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch +import torch.distributed +import torch.nn.functional as F + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import AttentionLayer, AttentionType +from vllm.attention.backends.flash_attn import (FlashAttentionBackend, + FlashAttentionImpl, + FlashAttentionMetadata, + FlashAttentionMetadataBuilder) +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.logger import init_logger +from vllm.utils import async_tensor_h2d +from vllm.platforms import current_platform +if not current_platform.is_rocm(): + from vllm.vllm_flash_attn import (flash_attn_varlen_func, + flash_attn_with_kvcache, sparse_attn_func) +else: + from flash_attn import (flash_attn_varlen_func, + flash_attn_with_kvcache, sparse_attn_func) + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + +logger = init_logger(__name__) + + +class DualChunkFlashAttentionBackend(FlashAttentionBackend): + + accept_output_buffer: bool = False + + @staticmethod + def get_name() -> str: + return "DUAL_CHUNK_FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> Type["DualChunkFlashAttentionImpl"]: + return DualChunkFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["DualChunkFlashAttentionMetadata"]: + return DualChunkFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["DualChunkFlashAttentionMetadataBuilder"]: + return DualChunkFlashAttentionMetadataBuilder + + +@dataclass +class DualChunkFlashAttentionMetadata(FlashAttentionMetadata): + # Block size of the paged kv cache. + block_size: int = 16 + + # Original max position embeddings. + original_max_position_embeddings: int = 0 + + # Chunk size + chunk_size: int = 8192 + + # Local size + local_size: int = 1024 + + # (batch_size,). The orig sequence length per sequence. + orig_seq_lens: Optional[List[int]] = None + + # orig_seq_lens stored as a tensor. + orig_seq_lens_tensor: Optional[torch.Tensor] = None + + # Length scaling factor + scaling_factor: Optional[torch.Tensor] = None + + # (batch_size,). Sequence lengths for intra attention. + seq_lens_intra: Optional[torch.Tensor] = None + + # Max sequence length for intra attention. + max_seq_len_intra: Optional[int] = None + + # (batch_size, num_blocks). Block table for intra attention. + block_tables_intra: Optional[torch.Tensor] = None + + # (batch_size,). Sequence lengths for succ attention. + seq_lens_succ: Optional[torch.Tensor] = None + + # Max sequence length for succ attention. + max_seq_len_succ: Optional[int] = None + + # (batch_size, num_blocks). Block table for succ attention. + block_tables_succ: Optional[torch.Tensor] = None + + # (batch_size,). Sequence lengths for inter attention. + seq_lens_inter: Optional[torch.Tensor] = None + + # Max sequence length for inter attention. + max_seq_len_inter: Optional[int] = None + + _cached_prefill_metadata: Optional[ + "DualChunkFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["DualChunkFlashAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["DualChunkFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + prefill_metadata = super().prefill_metadata + if prefill_metadata is None: + return None + + prefill_metadata = DualChunkFlashAttentionMetadata( + **prefill_metadata.asdict_zerocopy()) + + prefill_metadata.orig_seq_lens = ( + None if self.orig_seq_lens is None else + self.orig_seq_lens[:self.num_prefills]) + prefill_metadata.orig_seq_lens_tensor = ( + None if self.orig_seq_lens_tensor is None else + self.orig_seq_lens_tensor[:self.num_prefills]) + + if self.original_max_position_embeddings > 0: + assert prefill_metadata.orig_seq_lens_tensor is not None + prefill_metadata.scaling_factor = ( + 0.1 * torch.log(prefill_metadata.orig_seq_lens_tensor / + self.original_max_position_embeddings) + + 1.0).clip(min=1) + + self._cached_prefill_metadata = prefill_metadata + return prefill_metadata + + @property + def decode_metadata(self) -> Optional["DualChunkFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + + decode_metadata = super().decode_metadata + if decode_metadata is None: + return None + + decode_metadata = DualChunkFlashAttentionMetadata( + **decode_metadata.asdict_zerocopy()) + + decode_metadata.orig_seq_lens_tensor = ( + None if self.orig_seq_lens_tensor is None else + self.orig_seq_lens_tensor[self.num_prefills:]) + + assert decode_metadata.orig_seq_lens_tensor is not None + assert decode_metadata.block_tables is not None + + cache_seq_lens = decode_metadata.orig_seq_lens_tensor + chunk_len = self.chunk_size - self.local_size + chunk_num_curr = (cache_seq_lens - 1) // chunk_len + batch_size = decode_metadata.num_decode_tokens + + if self.original_max_position_embeddings > 0: + decode_metadata.scaling_factor = (0.1 * torch.log( + cache_seq_lens / self.original_max_position_embeddings) + + 1.0).clip(min=1) + + seq_lens_intra = cache_seq_lens - chunk_num_curr * chunk_len + max_seq_len_intra = seq_lens_intra.max().item() + decode_metadata.seq_lens_intra = seq_lens_intra + decode_metadata.max_seq_len_intra = max_seq_len_intra + + block_tables_intra = torch.zeros( + batch_size, + (max_seq_len_intra - 1) // self.block_size + 1, + dtype=decode_metadata.block_tables.dtype, + device=decode_metadata.block_tables.device, + ) + for i in range(batch_size): + st = chunk_num_curr[i] * chunk_len // self.block_size + ed = min( + st + (max_seq_len_intra - 1) // self.block_size + 1, + (cache_seq_lens[i] - 1) // self.block_size + 1, + ) + block_tables_intra[i, :ed - + st] = decode_metadata.block_tables[i, st:ed] + decode_metadata.block_tables_intra = block_tables_intra + + seq_lens_succ = (chunk_num_curr - + (chunk_num_curr - 1).clip(min=0)) * chunk_len + max_seq_len_succ = seq_lens_succ.max().item() + decode_metadata.seq_lens_succ = seq_lens_succ + decode_metadata.max_seq_len_succ = max_seq_len_succ + if max_seq_len_succ: + block_tables_succ = torch.zeros( + batch_size, + (max_seq_len_succ - 1) // self.block_size + 1, + dtype=decode_metadata.block_tables.dtype, + device=decode_metadata.block_tables.device, + ) + for i in range(batch_size): + start = ((chunk_num_curr[i] - 1).clip(min=0) * chunk_len // + self.block_size) + end = min( + start + (max_seq_len_succ - 1) // self.block_size + 1, + (cache_seq_lens[i] - 1) // self.block_size + 1, + ) + block_tables_succ[ + i, :end - start] = decode_metadata.block_tables[i, + start:end] + decode_metadata.block_tables_succ = block_tables_succ + + seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len + max_seq_len_inter = seq_lens_inter.max().item() + decode_metadata.seq_lens_inter = seq_lens_inter + decode_metadata.max_seq_len_inter = max_seq_len_inter + + self._cached_decode_metadata = decode_metadata + return decode_metadata + + +class DualChunkFlashAttentionMetadataBuilder(FlashAttentionMetadataBuilder): + + def prepare(self): + super().prepare() + self.orig_seq_lens: List[int] = [] + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + super()._add_seq_group(inter_data, chunked_prefill_enabled, + prefix_cache_hit) + for prompt_len, seq_len in zip(inter_data.prompt_lens, + inter_data.seq_lens): + self.orig_seq_lens.append(max(prompt_len, seq_len)) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + attn_metadata = super().build(seq_lens, query_lens, + cuda_graph_pad_size, batch_size) + attn_metadata = DualChunkFlashAttentionMetadata( + **attn_metadata.asdict_zerocopy()) + + device = self.runner.device + attn_metadata.orig_seq_lens = self.orig_seq_lens + attn_metadata.orig_seq_lens_tensor = async_tensor_h2d( + self.orig_seq_lens, torch.int, device, self.runner.pin_memory) + + attn_metadata.block_size = self.runner.block_size + dual_chunk_attn_config = getattr(self.runner.model_config.hf_config, + "dual_chunk_attention_config", {}) + attn_metadata.original_max_position_embeddings = \ + dual_chunk_attn_config.get("original_max_position_embeddings", 0) + attn_metadata.chunk_size = dual_chunk_attn_config.get( + "chunk_size", 8192) + attn_metadata.local_size = dual_chunk_attn_config.get( + "local_size", 1024) + + return attn_metadata + + +class DualChunkFlashAttentionImpl(FlashAttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + The prompts might have different lengths, while the generation tokens + always have length 1. + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + layer_idx: int = -1, + dual_chunk_attention_config: Optional[Dict[str, Any]] = None, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + if sliding_window is not None: + # NOTE(woosuk): flash-attn's sliding window does not work with + # paged KV cache. + raise ValueError( + "Sliding window is not supported in FlashAttention.") + + support_head_sizes = ( + DualChunkFlashAttentionBackend.get_supported_head_sizes()) + + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + + assert dual_chunk_attention_config is not None + self.chunk_size = dual_chunk_attention_config.get("chunk_size", 8192) + self.local_size = dual_chunk_attention_config.get("local_size", 1024) + self.original_max_position_embeddings = dual_chunk_attention_config.get( + "original_max_position_embeddings", 0) + self.sparse_attention_config = dual_chunk_attention_config.get( + "sparse_attention_config", None) + if not self.sparse_attention_config: + logger.warning_once("Sparse attention will not be enabled as " + "sparse attention config is not provided.") + self.sparse_attention_enabled = dual_chunk_attention_config.get( + "sparse_attention_enabled", self.sparse_attention_config + is not None) + self.sparse_attention_threshold = dual_chunk_attention_config.get( + "sparse_attention_threshold", 32768) + self.sparse_attention_last_q = dual_chunk_attention_config.get( + "sparse_attention_last_q", 64) + self.layer_idx = layer_idx + self.dual_chunk_attention_config = dual_chunk_attention_config + + if self.sparse_attention_config: + self.sparse_attention_config = { + int(i): j + for i, j in self.sparse_attention_config[ + self.layer_idx].items() + } + start_head = self.num_heads * get_tensor_model_parallel_rank() + end_head = start_head + self.num_heads + self.sparse_attention_config = [ + self.sparse_attention_config[i] + for i in range(start_head, end_head) + ] + + if self.sparse_attention_enabled: + self.arange = torch.arange(self.sparse_attention_last_q, + device="cuda") + self.last_q_mask = (self.arange[None, None, :, None] + >= self.arange[None, None, None, :]) + + def forward( # type: ignore + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: DualChunkFlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with DualChunkFlashAttention. + Args: + query: shape = [num_tokens, num_heads * head_size] + query_succ: shape = [num_tokens, num_heads * head_size] + query_inter: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is None, "Output tensor not supported for DualChunk" + + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for FlashAttentionImpl") + + ( + query, + query_succ, + query_inter, + query_succ_critical, + query_inter_critical, + ) = torch.split(query, query.shape[-1] // 5, dim=-1) + + assert ( + query_succ is not None and query_inter is not None + ), "query_succ and query_inter are required in Dual Chunk Attention." + + num_tokens, hidden_size = query.shape + + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + query_succ = query_succ.view(-1, self.num_heads, self.head_size) + query_inter = query_inter.view(-1, self.num_heads, self.head_size) + query_succ_critical = query_succ_critical.view(-1, self.num_heads, + self.head_size) + query_inter_critical = query_inter_critical.view( + -1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.original_max_position_embeddings > 0: + if prefill_meta := attn_metadata.prefill_metadata: + assert prefill_meta.scaling_factor is not None + assert prefill_meta.query_start_loc is not None + assert prefill_meta.orig_seq_lens is not None + current_start = 0 + query_start_loc_cpu = prefill_meta.query_start_loc.cpu() + for i in range(len(prefill_meta.orig_seq_lens)): + current_end = (current_start + + (query_start_loc_cpu[i + 1] - + query_start_loc_cpu[i]).item()) + key[current_start:current_end].mul_( + prefill_meta.scaling_factor[i]) + current_start = current_end + assert current_end <= attn_metadata.num_prefill_tokens + if decode_meta := attn_metadata.decode_metadata: + assert decode_meta.scaling_factor is not None + scaling_factor = decode_meta.scaling_factor + key[attn_metadata.num_prefill_tokens:].mul_( + scaling_factor.unsqueeze(-1).unsqueeze(-1)) + + if kv_cache is not None and kv_cache.numel() > 0: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + output = torch.empty_like(query) + + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + decode_query_succ = query_succ[num_prefill_tokens:] + decode_query_inter = query_inter[num_prefill_tokens:] + + # QKV for prefill. + query = query[:num_prefill_tokens] + query_succ = query_succ[:num_prefill_tokens] + query_inter = query_inter[:num_prefill_tokens] + query_succ_critical = query_succ_critical[:num_prefill_tokens] + query_inter_critical = query_inter_critical[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + if (kv_cache is None or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + # normal attention, called during the profiling run. + out = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out + else: + # prefix-enabled attention + assert prefill_meta.seq_lens is not None + assert prefill_meta.orig_seq_lens is not None + output[:num_prefill_tokens] = ( + self._dual_chunk_flash_attn_prefill( + q=query, + q_succ=query_succ, + q_inter=query_inter, + q_succ_critical=query_succ_critical, + q_inter_critical=query_inter_critical, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + orig_seq_lens=prefill_meta.orig_seq_lens, + scaling_factor=prefill_meta.scaling_factor, + softmax_scale=self.scale, + causal=True, + window_size=(-1, -1), + alibi_slopes=self.alibi_slopes, + block_table=prefill_meta.block_tables, + chunk_size=self.chunk_size, + local_size=self.local_size, + )) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + output[num_prefill_tokens:] = ( + self._dual_chunk_flash_attn_decoding( + decode_query.unsqueeze(1), + decode_query_succ.unsqueeze(1), + decode_query_inter.unsqueeze(1), + key_cache, + value_cache, + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + chunk_size=self.chunk_size, + local_size=self.local_size, + original_max_position_embeddings=self. + original_max_position_embeddings, + decode_meta=decode_meta, + ).squeeze(1)) + # Reshape the output tensor. + return output.view(num_tokens, hidden_size) + + def _dual_chunk_flash_attn_prefill( + self, + q, + q_succ, + q_inter, + q_succ_critical, + q_inter_critical, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + orig_seq_lens: List[int], + scaling_factor: torch.Tensor, + softmax_scale: float, + causal: Optional[bool] = True, + window_size: Tuple[int, int] = (-1, -1), + alibi_slopes: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, + chunk_size: int = 8192, + local_size: int = 1024, + ): + if alibi_slopes is not None: + raise ValueError( + "Dual Chunk Attention does not support alibi_slopes") + if not causal: + raise ValueError( + "Dual Chunk Attention does not support causal=False") + if window_size != (-1, -1): + raise ValueError( + "Dual Chunk Attention does not support window_size") + + cu_seqlens_q_cpu = cu_seqlens_q.cpu().tolist() + cu_seqlens_k_cpu = cu_seqlens_k.cpu().tolist() + all_outputs = [] + + for i in range(0, len(cu_seqlens_q_cpu) - 1): + qs = cu_seqlens_q_cpu[i] + qe = cu_seqlens_q_cpu[i:i + 2][-1] + ks = cu_seqlens_k_cpu[i] + ke = cu_seqlens_k_cpu[i:i + 2][-1] + + current_q = q[qs:qe] + current_q_succ = q_succ[qs:qe] + current_q_inter = q_inter[qs:qe] + current_q_succ_critical = q_succ_critical[qs:qe] + current_q_inter_critical = q_inter_critical[qs:qe] + + if block_table is None: + current_k = k[ks:ke] + current_v = v[ks:ke] + current_block_table = None + current_orig_seq_len = orig_seq_lens[i] + else: + current_block_table = block_table[i] + current_orig_seq_len = orig_seq_lens[i] + current_k = k + current_v = v + sparse_attn_enabled = (self.sparse_attention_enabled + and current_orig_seq_len + > self.sparse_attention_threshold) + + if current_q.shape[0] == 0: + continue + + if current_k.shape[0] == 0: + all_outputs.append( + torch.zeros( + (current_q.shape[0], current_q.shape[1], v.shape[2]), + device=q.device, + dtype=q.dtype, + )) + continue + + current_output = torch.empty_like(current_q) + group_size = int(current_q.size(-2) / current_k.size(-2)) + + if sparse_attn_enabled: + num_device_q_heads = current_q.size(-2) + heads_vertical_size = torch.empty(size=(num_device_q_heads, ), + dtype=torch.int32) + heads_slash_size = torch.empty(size=(num_device_q_heads, ), + dtype=torch.int32) + for head_id in range(current_q.size(-2)): + ( + ty, + vertical_size, + slash_size, + _, + ) = self.sparse_attention_config[head_id] + assert ty == "vertical_and_slash", "only support slash mode" + + if vertical_size == 30: + vertical_size += 100 + heads_vertical_size[head_id] = vertical_size + heads_slash_size[head_id] = slash_size + + current_output = self._dual_chunk_flash_attn_prefill_func( + current_q, # allheads + current_q_succ, + current_q_inter, + current_q_succ_critical, + current_q_inter_critical, + current_k, + current_v, + current_block_table, + softmax_scale, + chunk_size, + local_size, + scaling_factor[i].item(), + ke - ks, + sparse_attn_enabled=sparse_attn_enabled, + heads_vertical_size=heads_vertical_size, + heads_slash_size=heads_slash_size, + group_size=group_size) + else: + for head_id in range(current_q.size(-2)): + # (seq_len, num_heads, head_size) + current_q_head = current_q[:, head_id, :].unsqueeze(1) + current_q_succ_head = \ + current_q_succ[:, head_id, :].unsqueeze(1) + current_q_inter_head = \ + current_q_inter[:, head_id, :].unsqueeze(1) + current_q_succ_head_critical = \ + current_q_succ_critical[:, head_id, :].unsqueeze(1) + current_q_inter_head_critical = \ + current_q_inter_critical[:, head_id, :].unsqueeze(1) + if block_table is not None: + current_k_head = current_k[..., head_id // + group_size, :].unsqueeze(2) + current_v_head = current_v[..., head_id // + group_size, :].unsqueeze(2) + + else: + current_k_head = current_k[:, head_id, :].unsqueeze(1) + current_v_head = current_v[:, head_id, :].unsqueeze(1) + + current_out = self._dual_chunk_flash_attn_prefill_func( + current_q_head, + current_q_succ_head, + current_q_inter_head, + current_q_succ_head_critical, + current_q_inter_head_critical, + current_k_head, + current_v_head, + current_block_table, + softmax_scale, + chunk_size, + local_size, + scaling_factor[i].item(), + ke - ks, + sparse_attn_enabled=sparse_attn_enabled, + ) + current_output[:, head_id:head_id + 1, :] = current_out + all_outputs.append(current_output) + return torch.cat(all_outputs, dim=0) + + def _dual_chunk_flash_attn_prefill_func( + self, + q, + q_succ, + q_inter, + q_succ_critical, + q_inter_critical, + k, + v, + block_table, + softmax_scale: float, + chunk_size: int, + local_size: int, + scaling_factor: float, + k_length: int, + sparse_attn_enabled: Optional[bool] = True, + heads_vertical_size=None, + heads_slash_size=None, + group_size=None, + ): + flash_results = [] + chunk_len = chunk_size - local_size + + if block_table is not None: + block_size = v.shape[1] + if chunk_len % block_size != 0: + raise ValueError("chunk_len must be divisible by block_size.") + else: + block_size = 1 + + if self.original_max_position_embeddings > 0: + softmax_scale = softmax_scale * scaling_factor + + begin = k_length - q.shape[0] + while begin < k_length: + flash_per_chunk = [] + + prev_chunk_end_pos = (begin // chunk_len) * chunk_len + next_chunk_end_pos = prev_chunk_end_pos + chunk_len + end = min(next_chunk_end_pos, k_length) + qbegin = begin - (k_length - q.shape[0]) + qend = end - (k_length - q.shape[0]) + + qk_chunks = [] + q_states_intra = q[qbegin:qend] + # choose critical token + if block_table is not None: + block_tables_intra = _get_block(block_table, block_size, + prev_chunk_end_pos, end) + k_states_intra = k[block_tables_intra].view( + -1, *k.shape[-2:])[:(end - prev_chunk_end_pos)] + v_states_intra = v[block_tables_intra].view( + -1, *v.shape[-2:])[:(end - prev_chunk_end_pos)] + else: + block_tables_intra = None + k_states_intra = k[prev_chunk_end_pos:end] + v_states_intra = v[prev_chunk_end_pos:end] + + if sparse_attn_enabled: + last_q_size = min(qend - qbegin, self.sparse_attention_last_q) + _, num_device_k_heads, head_dim = k_states_intra.shape + k_states_intra = (k_states_intra.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, head_dim)) + v_states_intra = (v_states_intra.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, head_dim)) + qk_chunks.append( + (q_states_intra.transpose(0, 1)[:, -last_q_size:] * + softmax_scale) @ k_states_intra.permute(1, 2, 0)) + + if prev_chunk_end_pos - chunk_len >= 0: + q_states_succ = q_succ[qbegin:qend] + q_states_succ_critical = q_succ_critical[qbegin:qend] + if block_table is not None: + block_tables_succ = _get_block( + block_table, block_size, + prev_chunk_end_pos - chunk_len, prev_chunk_end_pos) + k_states_succ = k[block_tables_succ].view( + -1, *k.shape[-2:])[:chunk_len] + v_states_succ = v[block_tables_succ].view( + -1, *v.shape[-2:])[:chunk_len] + else: + k_states_succ = k[prev_chunk_end_pos - + chunk_len:prev_chunk_end_pos] + v_states_succ = v[prev_chunk_end_pos - + chunk_len:prev_chunk_end_pos] + + if sparse_attn_enabled: + k_states_succ = (k_states_succ.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, + head_dim)) + v_states_succ = (v_states_succ.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, + head_dim)) + qk_chunks.append((q_states_succ_critical.transpose( + 0, 1)[:, -last_q_size:] * softmax_scale) + @ k_states_succ.permute(1, 2, 0)) + + if prev_chunk_end_pos - chunk_len * 2 >= 0: + q_states_inter = q_inter[qbegin:qend] + q_states_inter_critical = q_inter_critical[qbegin:qend] + if block_table is not None: + block_tables_inter = _get_block( + block_table, block_size, 0, + prev_chunk_end_pos - chunk_len) + k_states_inter = k[block_tables_inter].view( + -1, *k.shape[-2:])[:(prev_chunk_end_pos - chunk_len)] + v_states_inter = v[block_tables_inter].view( + -1, *v.shape[-2:])[:(prev_chunk_end_pos - chunk_len)] + else: + k_states_inter = k[:prev_chunk_end_pos - chunk_len] + v_states_inter = v[:prev_chunk_end_pos - chunk_len] + + if sparse_attn_enabled: + k_states_inter = (k_states_inter.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, + head_dim)) + v_states_inter = (v_states_inter.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, + head_dim)) + qk_chunks.append((q_states_inter_critical.transpose( + 0, 1)[:, -last_q_size:] * softmax_scale) + @ k_states_inter.permute(1, 2, 0)) + + if sparse_attn_enabled: + reversed_qk = qk_chunks[::-1] + qk = torch.cat(reversed_qk, dim=-1) + + qk[:, :, -last_q_size:] = torch.where( + self.last_q_mask[..., -last_q_size:, + -last_q_size:].to(qk.device), + qk[:, :, -last_q_size:], -torch.inf) + qk = F.softmax(qk, dim=-1, dtype=torch.float32) + + vertical = qk.sum(-2, keepdim=True) + vertical[..., :30] = torch.inf + + # Avoid sorting by using the min/max ints to fill the indexer + # buffers. + int32_max = torch.iinfo(torch.int32).max + int32_min = torch.iinfo(torch.int32).min + n_heads = qk.size()[0] + max_slash_topk = torch.max(heads_slash_size).item() + max_vertical_topk = torch.max(heads_vertical_size).item() + # store each head's slash topk, vertical topk + vertical = vertical.reshape((n_heads, -1)) + # prevent out of range when prompt size < max_vertical_topk + max_vertical_topk = min(vertical.shape[-1], max_vertical_topk) + vertical_topk_buffer = torch.topk(vertical, max_vertical_topk, + -1).indices + slash_topk_buffer = torch.empty(size=(n_heads, max_slash_topk), + dtype=torch.int64, + device=qk.device) + for head_i in range(n_heads): + # (nqheads=1, lastq, k_len) + head_score = qk[head_i:head_i + 1, :, :] + slash_scores = _sum_all_diagonal_matrix(head_score) + if head_score.size(1) != 1: + # drop right up corner + slash_scores = slash_scores[..., :-last_q_size + 1] + slash_scores[..., -100:] = torch.inf + + head_slash_size = heads_slash_size[head_i] + head_slash_size = min(head_slash_size, vertical.size(-1)) + slash_topk = torch.topk(slash_scores, head_slash_size, + -1).indices + #(nheads, max_topk) + slash_topk_buffer[head_i, :head_slash_size] = slash_topk + + # reset heads topk + heads_slash_size[head_i] = head_slash_size + heads_vertical_size[head_i] = min( + heads_vertical_size[head_i], max_vertical_topk) + + # store + vertical_buffer = torch.full((n_heads, max_vertical_topk), + int32_max, + dtype=torch.int64, + device=q.device) + slash_buffer = torch.full((n_heads, max_slash_topk), + int32_min, + dtype=torch.int64, + device=q.device) + succ_vertical_buffer = torch.full((n_heads, max_vertical_topk), + int32_max, + dtype=torch.int64, + device=q.device) + succ_slash_buffer = torch.full((n_heads, max_slash_topk), + int32_min, + dtype=torch.int64, + device=q.device) + inter_vertical_buffer = torch.full( + (n_heads, max_vertical_topk), + int32_max, + dtype=torch.int64, + device=q.device) + inter_slash_buffer = torch.full((n_heads, max_slash_topk), + int32_min, + dtype=torch.int64, + device=q.device) + + vertical_size_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + slash_sizes_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + succ_vertical_size_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + succ_slash_sizes_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + inter_vertical_size_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + inter_slash_sizes_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + + for head_i in range(n_heads): + vertical_topk = vertical_topk_buffer[ + head_i, :heads_vertical_size[head_i]] + # intra + intra_vertical_indices = vertical_topk[ + vertical_topk >= + prev_chunk_end_pos] - prev_chunk_end_pos + if intra_vertical_indices.nelement() == 0: + intra_vertical_indices = torch.cat([ + intra_vertical_indices, + torch.arange(0, + k_states_intra.size(0), + max(1, + k_states_intra.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + slash_topk = slash_topk_buffer[ + head_i, :heads_slash_size[head_i]] + intra_slash_indices = ( + (qk.size(-1) - 1) - + slash_topk[slash_topk >= prev_chunk_end_pos]) + # fill buffer + v_count = intra_vertical_indices.nelement() + s_count = intra_slash_indices.nelement() + vertical_size_buffer[head_i] = v_count + slash_sizes_buffer[head_i] = s_count + vertical_buffer[head_i, :v_count].copy_( + intra_vertical_indices) + slash_buffer[head_i, :s_count].copy_(intra_slash_indices) + # succ + if prev_chunk_end_pos - chunk_len >= 0: + succ_vertical_indices = vertical_topk[ + (vertical_topk < prev_chunk_end_pos) + & (vertical_topk >= prev_chunk_end_pos - + chunk_len)] - (prev_chunk_end_pos - chunk_len) + # TODO: support no vertical + if succ_vertical_indices.nelement() == 0: + succ_vertical_indices = torch.cat([ + succ_vertical_indices, + torch.arange( + 0, + k_states_succ.size(0), + max(1, + k_states_succ.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + succ_slash_indices = ( + (prev_chunk_end_pos + (qend - qbegin) - 1) - + slash_topk[((slash_topk >= + (prev_chunk_end_pos - chunk_len)) & + (slash_topk < (prev_chunk_end_pos + + (qend - qbegin))))]) + if succ_slash_indices.nelement() == 0: + succ_slash_indices = torch.cat([ + succ_slash_indices, + torch.arange( + 0, + k_states_succ.size(0), + max(1, + k_states_succ.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + # fill buffer + v_count = succ_vertical_indices.nelement() + s_count = succ_slash_indices.nelement() + succ_vertical_size_buffer[head_i] = v_count + succ_slash_sizes_buffer[head_i] = s_count + succ_vertical_buffer[head_i, :v_count].copy_( + succ_vertical_indices) + succ_slash_buffer[head_i, :s_count].copy_( + succ_slash_indices) + + if prev_chunk_end_pos - 2 * chunk_len >= 0: + inter_vertical_indices = vertical_topk[ + vertical_topk < prev_chunk_end_pos - chunk_len] + + if inter_vertical_indices.nelement() == 0: + inter_vertical_indices = torch.cat([ + inter_vertical_indices, + torch.arange( + 0, + k_states_inter.size(0), + max(1, + k_states_inter.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + inter_slash_indices = ( + (prev_chunk_end_pos - chunk_len + + (qend - qbegin) - 1) - + slash_topk[slash_topk < (prev_chunk_end_pos - + chunk_len + + (qend - qbegin))]) + if inter_slash_indices.nelement() == 0: + inter_slash_indices = torch.cat([ + inter_slash_indices, + torch.arange( + 0, + k_states_inter.size(0), + max(1, + k_states_inter.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + # fill buffer + v_count = inter_vertical_indices.nelement() + s_count = inter_slash_indices.nelement() + inter_vertical_size_buffer[head_i] = v_count + inter_slash_sizes_buffer[head_i] = s_count + inter_vertical_buffer[head_i, :v_count].copy_( + inter_vertical_indices) + inter_slash_buffer[head_i, :s_count].copy_( + inter_slash_indices) + else: + intra_vertical_indices, intra_slash_indices = None, None + succ_vertical_indices, succ_slash_indices = None, None + inter_vertical_indices, inter_slash_indices = None, None + + if sparse_attn_enabled: + flash_result = self._do_flash_attn( + q_states_intra, + k_states_intra, + v_states_intra, + softmax_scale=softmax_scale, + causal=True, + block_table=block_table, + stage="intra", + vertical_indices=vertical_buffer, + slash_indices=slash_buffer, + vertical_indices_count=vertical_size_buffer, + slash_indices_count=slash_sizes_buffer, + mergehead_softmax_scale=softmax_scale, + sparse_attn_enabled=sparse_attn_enabled) + else: + flash_result = self._do_flash_attn( + q_states_intra, + k_states_intra, + v_states_intra, + softmax_scale=softmax_scale, + causal=True, + block_table=block_table, + stage="intra", + vertical_indices=intra_vertical_indices, + slash_indices=intra_slash_indices, + sparse_attn_enabled=sparse_attn_enabled) + flash_per_chunk.append(flash_result) + + if prev_chunk_end_pos - chunk_len >= 0: + if sparse_attn_enabled: + flash_result = self._do_flash_attn( + q_states_succ, + k_states_succ, + v_states_succ, + softmax_scale=softmax_scale, + causal=False, + block_table=block_table, + stage="succ", + vertical_indices=succ_vertical_buffer, + slash_indices=succ_slash_buffer, + vertical_indices_count=succ_vertical_size_buffer, + slash_indices_count=succ_slash_sizes_buffer, + mergehead_softmax_scale=softmax_scale, + sparse_attn_enabled=sparse_attn_enabled) + else: + flash_result = self._do_flash_attn( + q_states_succ, + k_states_succ, + v_states_succ, + softmax_scale=softmax_scale, + causal=False, + block_table=block_table, + stage="succ", + vertical_indices=succ_vertical_indices, + slash_indices=succ_slash_indices, + sparse_attn_enabled=sparse_attn_enabled) + flash_per_chunk.append(flash_result) + + if prev_chunk_end_pos - chunk_len * 2 >= 0: + if sparse_attn_enabled: + flash_result = self._do_flash_attn( + q_states_inter, + k_states_inter, + v_states_inter, + softmax_scale=softmax_scale, + causal=False, + block_table=block_table, + stage="inter", + vertical_indices=inter_vertical_buffer, + slash_indices=inter_slash_buffer, + vertical_indices_count=inter_vertical_size_buffer, + slash_indices_count=inter_slash_sizes_buffer, + mergehead_softmax_scale=softmax_scale, + sparse_attn_enabled=sparse_attn_enabled) + else: + flash_result = self._do_flash_attn( + q_states_inter, + k_states_inter, + v_states_inter, + softmax_scale=softmax_scale, + causal=False, + block_table=block_table, + stage="inter", + vertical_indices=inter_vertical_indices, + slash_indices=inter_slash_indices, + sparse_attn_enabled=sparse_attn_enabled) + flash_per_chunk.append(flash_result) + + flash_results.append(flash_per_chunk) + begin = end + + attn_output = self._merge_attn_outputs(flash_results) + del flash_results + return attn_output + + def _do_flash_attn( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + softmax_scale: float, + causal: bool = True, + block_table: torch.Tensor = None, + max_seqlen_k: Optional[int] = None, + stage: str = "intra", + vertical_indices: Optional[torch.Tensor] = None, + slash_indices: Optional[torch.Tensor] = None, + vertical_indices_count: Optional[torch.Tensor] = None, + slash_indices_count: Optional[torch.Tensor] = None, + mergehead_softmax_scale: Optional[float] = None, + sparse_attn_enabled: Optional[bool] = False, + ): + if max_seqlen_k is None: + max_seqlen_k = key_states.shape[0] + + q_len = query_states.shape[0] + q_heads = query_states.shape[1] + h_dim = query_states.shape[-1] + + if sparse_attn_enabled: + assert slash_indices is not None + if stage == "intra": + assert causal + else: + assert not causal + + query_states = query_states.unsqueeze(0).transpose(1, 2) + key_states = key_states.unsqueeze(0).transpose(1, 2) + value_states = value_states.unsqueeze(0).transpose(1, 2) + + q = query_states + k = key_states + v = value_states + + if (vertical_indices_count is not None and \ + slash_indices_count is not None): + assert mergehead_softmax_scale is not None + + res, s_lse = _vertical_slash_sparse_attention( + q, + k, + v, + vertical_indices, + slash_indices, + mergehead_softmax_scale, + causal=causal, + stage=stage, + vertical_indices_count=vertical_indices_count, + slash_indices_count=slash_indices_count) + res = res.view(q_heads, q_len, + h_dim).transpose(0, 1) # (qlen,nhead,h_dim) + s_lse = s_lse.view( + q_heads, q_len, + 1).squeeze(-1).unsqueeze(0).float() # (1, nhead,qlen) + else: + res, s_lse = _vertical_slash_sparse_attention(q, + k, + v, + vertical_indices, + slash_indices, + softmax_scale, + causal=causal, + stage=stage) + res = res.view(q_len, q_heads, h_dim) + s_lse = s_lse.view(q_len, q_heads, 1).transpose(0, 2).float() + return res, s_lse + + if not current_platform.is_rocm(): + output, softmax_lse = flash_attn_varlen_func( + q=query_states, + k=key_states, + v=value_states, + softmax_scale=softmax_scale, + cu_seqlens_q=torch.tensor([0, query_states.shape[0]], + dtype=torch.int32, + device=query_states.device), + max_seqlen_q=query_states.shape[0], + cu_seqlens_k=torch.tensor([0, max_seqlen_k], + dtype=torch.int32, + device=query_states.device), + max_seqlen_k=max_seqlen_k, + causal=causal, + block_table=block_table.unsqueeze(0), + return_softmax_lse=True, + ) + else: + output, softmax_lse = flash_attn_varlen_func( + q=query_states, + k=key_states, + v=value_states, + softmax_scale=softmax_scale, + cu_seqlens_q=torch.tensor([0, query_states.shape[0]], + dtype=torch.int32, + device=query_states.device), + max_seqlen_q=query_states.shape[0], + cu_seqlens_k=torch.tensor([0, max_seqlen_k], + dtype=torch.int32, + device=query_states.device), + max_seqlen_k=max_seqlen_k, + causal=causal, + block_table=block_table.unsqueeze(0), + return_attn_probs=True, + ) + softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0, + 2).float() + return output, softmax_lse + + def _merge_attn_outputs( + self, + flash_results: List[List[Tuple[torch.Tensor, torch.Tensor]]], + return_lse: Optional[bool] = False, + ) -> torch.Tensor: + attn_outputs_all = [] + logits_all = [] + + for flash_per_chunk in flash_results: + if len(flash_per_chunk) == 1: + attn_outputs_all.append(flash_per_chunk[0][0]) + if return_lse: + logits_all.append(flash_per_chunk[0][1]) + continue + + attn_outputs = torch.stack([ + flash_attn_output[0] for flash_attn_output in flash_per_chunk + ]) + logits = torch.stack([ + flash_attn_output[1] for flash_attn_output in flash_per_chunk + ]) + logits = logits.to(torch.float32) + + if return_lse: + max_val = torch.max(logits, dim=0).values + diff = torch.abs(logits[0] - logits[1]) + log_sum_exp = max_val + torch.log1p(torch.exp(-diff)) + logits_all.append(log_sum_exp) + + max_logits = torch.max(logits, dim=0).values + stable_logits = logits - max_logits.unsqueeze(0) + lse_s = torch.exp(stable_logits).detach() + lse_sum = torch.sum(lse_s, dim=0) + lse_s /= lse_sum + attn_outputs *= lse_s.unsqueeze(-1).transpose(2, 3).squeeze(1) + attn_outputs_all.append(attn_outputs.sum(dim=0)) + + if return_lse: + return (torch.cat(attn_outputs_all, + dim=0), torch.cat(logits_all, dim=-1)) + else: + return torch.cat(attn_outputs_all, dim=0) + + def _dual_chunk_flash_attn_decoding( + self, + query: torch.Tensor, + query_succ: torch.Tensor, + query_inter: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + softmax_scale: float, + causal: bool, + alibi_slopes: Optional[torch.Tensor], + chunk_size: int, + local_size: int, + original_max_position_embeddings: int, + decode_meta: DualChunkFlashAttentionMetadata, + ): + if not causal: + raise ValueError( + "Dual Chunk Attention does not support causal=False") + + block_size = value_cache.shape[1] + chunk_len = chunk_size - local_size + if chunk_len % block_size != 0: + raise ValueError("chunk_len must be divisible by block_size.") + if original_max_position_embeddings > 0: + assert decode_meta.scaling_factor is not None + scaling_factor = decode_meta.scaling_factor + query = (query * scaling_factor.view(-1, 1, 1, 1)).to( + query.dtype + ) # possible for numerical issue, need to fused in the kernel + query_succ = (query_succ * scaling_factor.view(-1, 1, 1, 1)).to( + query.dtype) + query_inter = (query_inter * scaling_factor.view(-1, 1, 1, 1)).to( + query.dtype) + outputs_list = [] + softmax_lses_list = [] + + # intra-attention + intra_output, intra_softmax_lse = ( + self._dual_chunk_flash_attn_decoding_with_exp_sums( + query, + key_cache, + value_cache, + decode_meta.block_tables_intra, + decode_meta.seq_lens_intra, + softmax_scale, + alibi_slopes, + causal=False, + )) + outputs_list.append(intra_output) + softmax_lses_list.append(intra_softmax_lse) + + # succ-attention + if decode_meta.max_seq_len_succ: + succ_output, succ_softmax_lse = ( + self._dual_chunk_flash_attn_decoding_with_exp_sums( + query_succ, + key_cache, + value_cache, + decode_meta.block_tables_succ, + decode_meta.seq_lens_succ, + softmax_scale, + alibi_slopes, + causal=False, + )) + outputs_list.append(succ_output) + softmax_lses_list.append(succ_softmax_lse) + + # inter-attention + if decode_meta.max_seq_len_inter: + inter_output, inter_softmax_lse = ( + self._dual_chunk_flash_attn_decoding_with_exp_sums( + query_inter, + key_cache, + value_cache, + block_table[:, :decode_meta.max_seq_len_inter], + decode_meta.seq_lens_inter, + softmax_scale, + alibi_slopes, + causal=False, + )) + outputs_list.append(inter_output) + softmax_lses_list.append(inter_softmax_lse) + outputs = torch.stack(outputs_list, dim=0) + del outputs_list + softmax_lses = torch.stack(softmax_lses_list, dim=0).to(torch.float32) + del softmax_lses_list + max_logits = torch.max(softmax_lses, dim=0).values + stable_logits = softmax_lses - max_logits.unsqueeze(0) + lse_s = torch.exp(stable_logits).detach() + lse_sum = torch.sum(lse_s, dim=0) + lse_s /= lse_sum + outputs *= lse_s.unsqueeze(-1).transpose(2, 3) + return outputs.sum(0) + + def _dual_chunk_flash_attn_decoding_with_exp_sums( + self, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + softmax_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + ): + out, softmax_lse = flash_attn_with_kvcache( + q=query, + k_cache=key_cache, + v_cache=value_cache, + block_table=block_table, + cache_seqlens=cache_seqlens, + softmax_scale=softmax_scale, + alibi_slopes=alibi_slopes, + causal=causal, + return_softmax_lse=True, + ) + mask = (cache_seqlens == 0) + out[mask] = 0 + softmax_lse[mask] = -float("inf") + return out, softmax_lse + + +def _vertical_slash_sparse_attention( + query: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] + key: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD] + value: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD] + v_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] + s_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] + softmax_scale: float, + causal: bool = True, + stage: str = "intra", + block_size_M: int = 64, + block_size_N: int = 64, + vertical_indices_count: torch.Tensor = None, # [N_HEADS,] + slash_indices_count: torch.Tensor = None, +): + if stage == "intra": + assert causal + else: + assert not causal + + batch_size, num_heads, context_size, head_dim = query.shape + _, _, kv_seq_len, _ = key.shape + + if head_dim not in [16, 32, 64, 128, 256, 512]: + target_dim = 2**math.ceil(math.log2(head_dim)) - head_dim + query = F.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0]) + key = F.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0]) + value = F.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0]) + + v_idx = v_idx.to(torch.int32).reshape( + (batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0] + s_idx = s_idx.to(torch.int32).reshape( + (batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0] + q_seqlens = torch.tensor([context_size], + dtype=torch.int32, + device=query.device) + kv_seqlens = torch.tensor([kv_seq_len], + dtype=torch.int32, + device=query.device) + + if vertical_indices_count is not None and slash_indices_count is not None: + ( + block_count, + block_offset, + column_count, + column_index, + ) = ops.convert_vertical_slash_indexes_mergehead( + q_seqlens, kv_seqlens, v_idx, s_idx, vertical_indices_count, + slash_indices_count, context_size, block_size_M, block_size_N, + causal) + else: + ( + block_count, + block_offset, + column_count, + column_index, + ) = ops.convert_vertical_slash_indexes(q_seqlens, kv_seqlens, v_idx, + s_idx, context_size, + block_size_M, block_size_N, + causal) + + q = query.transpose(1, 2).contiguous() + k = key.transpose(1, 2).contiguous() + v = value.transpose(1, 2).contiguous() + out, lse = sparse_attn_func( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + causal=causal, + softmax_scale=softmax_scale, + return_softmax_lse=True, + ) + out = out.transpose(1, 2).contiguous() + softmax_lse = lse.reshape(*lse.shape, 1) + return (out[..., :context_size, :head_dim], + softmax_lse[..., :context_size, :]) + + +def _sum_all_diagonal_matrix(mat: torch.tensor): + h, n, m = mat.shape + # Zero matrix used for padding + zero_mat = torch.zeros((h, n, n), device=mat.device) + # pads the matrix on left and right + mat_padded = torch.cat((zero_mat, mat, zero_mat), -1) + # Change the strides + mat_strided = mat_padded.as_strided((1, n, n + m), + (n * (2 * n + m), 2 * n + m + 1, 1)) + # Sums the resulting matrix's columns + sum_diags = torch.sum(mat_strided, 1) + return sum_diags[:, 1:] # drop left bottom corner + + +def _get_block(block_table: torch.Tensor, block_size: int, begin: int, + end: int): + begin_block = begin // block_size + end_block = (end - 1) // block_size + 1 + return block_table[begin_block:end_block] diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py new file mode 100644 index 0000000..4077fbe --- /dev/null +++ b/vllm/attention/backends/flash_attn.py @@ -0,0 +1,1084 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer with FlashAttention.""" +from collections import defaultdict +from dataclasses import dataclass +from itertools import accumulate +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm import _custom_ops as ops +# yapf conflicts with isort for this block +# yapf: disable +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType, + is_quantized_kv_cache) +# yapf: enable +from vllm.attention.backends.utils import ( + PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, + compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, + get_seq_len_block_table_args, is_all_cross_attn_metadata_set, + is_all_encoder_attn_metadata_set, is_block_tables_empty) +from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, + get_flash_attn_version) +from vllm.logger import init_logger +from vllm.multimodal import MultiModalPlaceholderMap +from vllm.utils import async_tensor_h2d, make_tensor_with_pad +from vllm.platforms import current_platform +if not current_platform.is_rocm(): + from vllm.vllm_flash_attn import (flash_attn_varlen_func, + flash_attn_with_kvcache) +else: + from flash_attn import (flash_attn_varlen_func, vllm_flash_attn_varlen_func, + flash_attn_with_kvcache) + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + +logger = init_logger(__name__) + + +class FlashAttentionBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> Type["FlashAttentionImpl"]: + return FlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return FlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: + return FlashAttentionMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +@dataclass +class FlashAttentionMetadata(AttentionMetadata): + """Metadata for FlashAttentionBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + + use_cuda_graph: bool + + # Maximum query length in the batch. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None + + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + encoder_seq_start_loc: Optional[torch.Tensor] = None + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return is_all_encoder_attn_metadata_set(self) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return is_all_cross_attn_metadata_set(self) + + @property + def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert ((self.seq_lens is not None) + or (self.encoder_seq_lens is not None)) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) + + self._cached_prefill_metadata = FlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_query_len=0, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + + self._cached_decode_metadata = FlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=self.max_decode_query_len, + max_query_len=self.max_query_len, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + # Batch may be composed of prefill|decodes, adjust query start + # indices to refer to the start of decodes. E.g. + # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, + context_lens_tensor=None, + block_tables=block_tables, + use_cuda_graph=self.use_cuda_graph, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + if turn_prefills_into_decodes: + # When Multi-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. + assert self.num_decode_tokens + self.num_prefills == num_seqs + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +class FlashAttentionMetadataBuilder( + AttentionMetadataBuilder[FlashAttentionMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + self.has_prefix_cache_hit = False + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def _get_graph_runner_block_tables( + self, num_seqs: int, + block_tables: List[List[int]]) -> torch.Tensor: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + max_batch_size, max_blocks = self.runner.graph_block_tables.shape + assert max_batch_size >= num_seqs + + graph_block_tables = self.runner.graph_block_tables[:num_seqs] + for i, block_table in enumerate(block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + graph_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + graph_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + return torch.from_numpy(graph_block_tables).to( + device=self.runner.device, non_blocking=True) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int, + tree_attention_masks_tensor: Optional[torch.Tensor] = None): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + tree_attention_masks_tensor: attention mask used in tree style attention. + """ + prefix_cache_hit = any([ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ]) + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + num_seqs = len(seq_lens) + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + + return FlashAttentionMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + +class FlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") + if use_irope: + logger.warning( + "Using irope in V0 is not supported yet, it will fall back " + "to global attention for long context.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window - 1, + 0) if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + self.vllm_flash_attn_version = get_flash_attn_version( + requires_alibi=self.alibi_slopes is not None) + if is_quantized_kv_cache(self.kv_cache_dtype) and ( + not self.kv_cache_dtype.startswith("fp8") + or not flash_attn_supports_fp8()): + raise NotImplementedError( + f"FlashAttention does not support {self.kv_cache_dtype} " + "kv-cache on this device " + f"(FA supports fp8 = {flash_attn_supports_fp8()}).") + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + self.attn_type = attn_type + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + output: shape = [num_tokens, num_heads, head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + NOTE: It in-place updates the output tensor. + NOTE: FP8 quantization, flash-attn expect the size of + {q,k,v}_descale to be (num_sequences, num_kv_heads). + We use torch's .expand() to avoid duplicating values + """ + assert output is not None, "Output tensor must be provided." + + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for FlashAttentionImpl") + + # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache. + if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16: + assert ( + layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), ( + "key/v_scale is only supported in FlashAttention 3 with " + "base dtype bfloat16") + + attn_type = self.attn_type + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + + kv_cache_dtype: str = self.kv_cache_dtype + softmax_scale: float = self.scale + window_size = self.sliding_window + alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes + logits_soft_cap: Optional[float] = self.logits_soft_cap + fp8_attention = kv_cache_dtype.startswith("fp8") + + if fp8_attention and not flash_attn_supports_fp8(): + raise NotImplementedError( + "FlashAttention does not support FP8 kv-cache on this device.") + + if kv_cache.numel() > 0: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + # We skip updating the KV cache under two conditions: + # a. When the Attention Type is ENCODER. In this phase, we compute + # only the encoder attention without updating the cache. + # b. When both Key and Value are None. This occurs during + # cross-attention computation in the decoding phase, where the + # KV cache is already populated with the cross-attention + # tensor. Thus, we skip cache updates during this time. + if (attn_type != AttentionType.ENCODER) and (key is not None) and ( + value is not None): + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory + # profiling run. + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[0], + kv_cache[1], + updated_slot_mapping.flatten(), # type: ignore[union-attr] + kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if fp8_attention: + kv_cache = kv_cache.view(torch.float8_e4m3fn) + key_cache = key_cache.view(torch.float8_e4m3fn) + value_cache = value_cache.view(torch.float8_e4m3fn) + + if fp8_attention: + num_tokens, num_heads, head_size = query.shape + query, _ = ops.scaled_fp8_quant( + query.reshape( + (num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale) + query = query.reshape((num_tokens, num_heads, head_size)) + + (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) = \ + get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) + decode_query = query[num_prefill_query_tokens:] + decode_output = output[num_prefill_query_tokens:] + # QKV for prefill. + query = query[:num_prefill_query_tokens] + prefill_output = output[:num_prefill_query_tokens] + assert query.shape[0] == num_prefill_query_tokens + assert decode_query.shape[0] == num_decode_query_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + if (kv_cache.numel() == 0 or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \ + _get_query_key_seq_metadata(prefill_meta, True, attn_type) + + key = key[:num_prefill_kv_tokens] + value = value[:num_prefill_kv_tokens] + + if fp8_attention: + num_kv_tokens, num_kv_heads, head_size = key.shape + + key, _ = ops.scaled_fp8_quant( + key.reshape((num_kv_tokens, + num_kv_heads * head_size)).contiguous(), + layer._k_scale) + key = key.reshape((num_kv_tokens, num_kv_heads, head_size)) + + value, _ = ops.scaled_fp8_quant( + value.reshape((num_kv_tokens, + num_kv_heads * head_size)).contiguous(), + layer._v_scale) + value = value.reshape( + (num_kv_tokens, num_kv_heads, head_size)) + + descale_shape = (q_seq_start_loc.shape[0] - 1, key.shape[1]) + if not current_platform.is_rocm(): + flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=q_seq_start_loc, + cu_seqlens_k=k_seq_start_loc, + max_seqlen_q=q_seq_len, + max_seqlen_k=k_seq_len, + softmax_scale=softmax_scale, + causal=_get_causal_option(attn_type), + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + out=prefill_output, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + else: + prefill_output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=q_seq_start_loc, + cu_seqlens_k=k_seq_start_loc, + max_seqlen_q=q_seq_len, + max_seqlen_k=k_seq_len, + softmax_scale=softmax_scale, + causal=_get_causal_option(attn_type), + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + ) + else: + # prefix-enabled attention + assert attn_type == AttentionType.DECODER, ( + "Only decoder-only models support prefix caching") + assert prefill_meta.seq_lens is not None + assert prefill_meta.query_start_loc is not None + max_seq_len = max(prefill_meta.seq_lens) + descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, + key.shape[1]) + if not current_platform.is_rocm(): + flash_attn_varlen_func( # noqa + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + seqused_k=prefill_meta.seq_lens_tensor, + max_seqlen_k=max_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + block_table=prefill_meta.block_tables, + softcap=logits_soft_cap, + out=prefill_output, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + else: + vllm_flash_attn_varlen_func( # noqa + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + seqused_k=prefill_meta.seq_lens_tensor, + max_seqlen_k=max_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + block_table=prefill_meta.block_tables, + softcap=logits_soft_cap, + out=prefill_output, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + # Use flash_attn_varlen_func kernel for speculative decoding + # because different queries might have different lengths. + + assert decode_meta.max_decode_query_len is not None + # use only for actual varlen decoding + if decode_meta.max_decode_query_len > 1: + assert attn_type == AttentionType.DECODER, ( + "Only decoder-only models support max_decode_query_len > 1" + ) + assert decode_meta.query_start_loc is not None + descale_shape = (decode_meta.query_start_loc.shape[0] - 1, + key.shape[1]) + if not current_platform.is_rocm(): + flash_attn_varlen_func( + q=decode_query, + k=key_cache, + v=value_cache, + cu_seqlens_q=decode_meta.query_start_loc, + max_seqlen_q=decode_meta.max_decode_query_len, + seqused_k=decode_meta.seq_lens_tensor, + max_seqlen_k=decode_meta.max_decode_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + block_table=decode_meta.block_tables, + out=decode_output, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + else: + decode_output = flash_attn_varlen_func( + q=decode_query, + k=key_cache, + v=value_cache, + cu_seqlens_q=decode_meta.query_start_loc, + max_seqlen_q=decode_meta.max_decode_query_len, + seqused_k=decode_meta.seq_lens_tensor, + max_seqlen_k=decode_meta.max_decode_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + block_table=decode_meta.block_tables, + ) + else: + # Use flash_attn_with_kvcache for normal decoding. + ( + seq_lens_arg, + _, + block_tables_arg, + ) = get_seq_len_block_table_args(decode_meta, False, attn_type) + descale_shape = (seq_lens_arg.shape[0], key_cache.shape[-2]) + if not current_platform.is_rocm(): + flash_attn_with_kvcache( + q=decode_query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, + block_table=block_tables_arg, + cache_seqlens=seq_lens_arg, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + out=decode_output.unsqueeze(1), + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + else: + decode_output = flash_attn_with_kvcache( + q=decode_query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, + block_table=block_tables_arg, + cache_seqlens=seq_lens_arg, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + ) + return output + + +def _get_query_key_seq_metadata( + attn_metadata, + is_prompt: bool, + attn_type: str, +) -> tuple: + """ + Returns sequence metadata for key and query based on the specified + attention type and whether input is a prompt. + + This function computes the starting locations and maximum sequence lengths + for key and query sequences for different attention types. + + Args: + attn_metadata: The attention metadata object + is_prompt (bool): A flag indicating if the input is a prompt + attn_type (AttentionType): The type of attention being used. + + Returns: + tuple: A tuple containing four integers: + - Starting location for the query sequence. + - Maximum sequence length for the query sequence. + - Starting location for the key sequence. + - Maximum sequence length for the key sequence. + + Raises: + AttributeError: If an invalid attention type is provided. + """ + if attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_start_loc, max_seq_len, + attn_metadata.seq_start_loc, max_seq_len) + + elif attn_type == AttentionType.ENCODER_DECODER: + # This is cross attention between the where the key + # is the precomputed encoder attention and query + # is the input sequence. + # Choose query max length based on whether it is prompt + # or not. + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_start_loc, max_seq_len, + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len) + elif attn_type == AttentionType.ENCODER: + # For encoder attention both the query and the key are same i.e the + # encoder sequence. + return (attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len) + elif attn_type == AttentionType.ENCODER_ONLY: + assert is_prompt, "Should not have decode for encoder only model." + return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len, + attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def _get_causal_option(attn_type: str) -> bool: + """ + Determine whether the given attention type is suitable for causal + attention mechanisms. + + Args: + attn_type (AttentionType): The type of attention being evaluated + + Returns: + bool: Returns `True` if the attention type is suitable for causal + attention (i.e., not encoder, encoder-only, or encoder-decoder), + otherwise returns `False`. + """ + return not (attn_type == AttentionType.ENCODER + or attn_type == AttentionType.ENCODER_ONLY + or attn_type == AttentionType.ENCODER_DECODER) \ No newline at end of file diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py new file mode 100644 index 0000000..c672787 --- /dev/null +++ b/vllm/attention/backends/flashinfer.py @@ -0,0 +1,1109 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type + +from vllm.multimodal import MultiModalPlaceholderMap + +try: + from flashinfer import BatchDecodeWithPagedKVCacheWrapper + from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper + from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper + + from vllm.vllm_flash_attn import flash_attn_varlen_func + FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 +except ImportError: + # Avoid turning these types into variables during type checking + if not TYPE_CHECKING: + BatchDecodeWithPagedKVCacheWrapper = None + CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None + BatchPrefillWithPagedKVCacheWrapper = None + FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionState, AttentionType) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.attention.layer import Attention +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.logger import init_logger +from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, + make_tensor_with_pad) + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + +FLASHINFER_KV_CACHE_LAYOUT: str = envs.VLLM_KV_CACHE_LAYOUT or "NHD" + + +class FlashInferBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "FLASHINFER" + + @staticmethod + def get_impl_cls() -> Type["FlashInferImpl"]: + return FlashInferImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return FlashInferMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashInferMetadataBuilder"]: + return FlashInferMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["FlashInferState"]: + return FlashInferState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, 2, block_size, num_kv_heads, head_size) + + @staticmethod + def get_kv_cache_stride_order() -> Tuple[int, ...]: + cache_layout = FLASHINFER_KV_CACHE_LAYOUT + assert (cache_layout in ("NHD", "HND")) + stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, + 2, 4) + return stride_order + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 128, 256] + + @staticmethod + def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + return torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + return torch.float8_e5m2 + else: + raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + + +@dataclass +class PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters. + """ + + window_left: int + logits_soft_cap: Optional[float] + sm_scale: float + + +def get_per_layer_parameters( + vllm_config: VllmConfig) -> Dict[str, PerLayerParameters]: + """ + Scan all attention layers and determine some hyperparameters + to use during `plan`. + """ + + layers = get_layers_from_vllm_config(vllm_config, Attention) + per_layer_params: Dict[str, PerLayerParameters] = {} + + for key, layer in layers.items(): + impl = layer.impl + assert isinstance(impl, FlashInferImpl) + + # Infer hyperparameters from the attention layer + window_size = impl.sliding_window + window_left = window_size[0] if window_size is not None else -1 + logits_soft_cap = impl.logits_soft_cap + sm_scale = impl.scale + + per_layer_params[key] = PerLayerParameters(window_left, + logits_soft_cap, sm_scale) + + return per_layer_params + + +def infer_global_hyperparameters( + per_layer_params: Dict[str, PerLayerParameters]) -> PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters: + - `window_left` + - `logits_soft_cap` + - `sm_scale` + + So this function asserts that all layers share the same values for these + hyperparameters and returns the global values. + """ + + assert len(per_layer_params) > 0, "No attention layers found in the model." + + param_sets = list(per_layer_params.values()) + global_params = param_sets[0] + for params in param_sets: + assert params == global_params, ( + "FlashInfer backend currently only supports models in which all " + "layers share the same values for the following hyperparameters: " + "`window_left`, `logits_soft_cap`, `sm_scale`.") + + return global_params + + +class FlashInferState(AttentionState): + + def __init__(self, runner): + self.runner = runner + self._is_graph_capturing = False + self._workspace_buffer = None + self._decode_wrapper = None + self._prefill_wrapper = None + + # Global hyperparameters shared by all attention layers + self.global_hyperparameters: Optional[PerLayerParameters] = None + + self.vllm_config = self.runner.vllm_config + self._kv_cache_layout = None + + def _get_workspace_buffer(self): + if self._workspace_buffer is None: + self._workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.runner.device) + return self._workspace_buffer + + def get_kv_cache_layout(self): + if self._kv_cache_layout is None: + self._kv_cache_layout = FLASHINFER_KV_CACHE_LAYOUT + return self._kv_cache_layout + + def _get_prefill_wrapper(self): + if self._prefill_wrapper is None: + self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( + self._get_workspace_buffer(), self.get_kv_cache_layout()) + return self._prefill_wrapper + + def _get_decode_wrapper(self): + if self._decode_wrapper is None: + num_qo_heads = (self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config)) + num_kv_heads = self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config) + use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( + num_qo_heads // num_kv_heads > 4) + self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self._get_workspace_buffer(), + self.get_kv_cache_layout(), + use_tensor_cores=use_tensor_cores) + return self._decode_wrapper + + @contextmanager + def graph_capture(self, max_batch_size: int): + self._is_graph_capturing = True + self._graph_decode_wrapper = None + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + self._graph_decode_workspace_buffer = self._get_workspace_buffer() + self._graph_indices_buffer = torch.empty( + max_batch_size * self.runner.cache_config.num_gpu_blocks, + dtype=torch.int32, + device=self.runner.device) + self._graph_indptr_buffer = torch.empty(max_batch_size + 1, + dtype=torch.int32, + device=self.runner.device) + self._graph_last_page_len_buffer = torch.empty( + max_batch_size, dtype=torch.int32, device=self.runner.device) + yield + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + del self._graph_decode_workspace_buffer + del self._graph_indices_buffer + del self._graph_indptr_buffer + del self._graph_last_page_len_buffer + del self._graph_decode_wrapper + + def graph_clone(self, batch_size: int): + assert self._is_graph_capturing + state = self.__class__(self.runner) + state._workspace_buffer = self._graph_decode_workspace_buffer + state._decode_wrapper = self._graph_decode_wrapper + state._prefill_wrapper = self._get_prefill_wrapper() + return state + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + assert self._is_graph_capturing + _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1] + _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size] + + num_qo_heads = (self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config)) + num_kv_heads = self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config) + use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( + num_qo_heads // num_kv_heads > 4) + self._graph_decode_wrapper = \ + CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + self._graph_decode_workspace_buffer, _indptr_buffer, + self._graph_indices_buffer, _last_page_len_buffer, + self.get_kv_cache_layout(), + use_tensor_cores) + if self.runner.kv_cache_dtype.startswith("fp8"): + kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.runner.kv_cache_dtype) + else: + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) + + paged_kv_indptr_tensor_host = torch.arange(0, + batch_size + 1, + dtype=torch.int32) + paged_kv_indices_tensor_host = torch.arange(0, + batch_size, + dtype=torch.int32) + paged_kv_last_page_len_tensor_host = torch.full((batch_size, ), + self.runner.block_size, + dtype=torch.int32) + query_start_loc_host = torch.arange(0, + batch_size + 1, + dtype=torch.int32) + + global_params = infer_global_hyperparameters( + get_per_layer_parameters(self.vllm_config)) + + attn_metadata = self.runner.attn_backend.make_metadata( + num_prefills=0, + slot_mapping=self._graph_slot_mapping[:batch_size], + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + max_prefill_seq_len=0, + block_tables=self._graph_block_tables, + paged_kv_indptr=paged_kv_indptr_tensor_host, + paged_kv_indices=paged_kv_indices_tensor_host, + paged_kv_last_page_len=paged_kv_last_page_len_tensor_host, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim=self.runner.model_config.get_head_size(), + page_size=self.runner.block_size, + seq_start_loc=None, + query_start_loc=query_start_loc_host, + device=self.runner.device, + data_type=kv_cache_dtype, + q_data_type=self.runner.model_config.dtype, + use_cuda_graph=True, + decode_wrapper=self._graph_decode_wrapper, + prefill_wrapper=None, + **dataclasses.asdict(global_params), + ) + attn_metadata.begin_forward() + return attn_metadata + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + return { + "slot_mapping": attn_metadata.slot_mapping, + } + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + return + + def begin_forward(self, model_input): + assert not self._is_graph_capturing + state = self + use_cuda_graph = model_input.attn_metadata.use_cuda_graph + is_decode = model_input.attn_metadata.num_prefills == 0 + # In case of multistep chunked-prefill, there might be prefill requests + # scheduled while CUDA graph mode is enabled. We don't run graph in that + # case. + if use_cuda_graph and is_decode: + if model_input.inputs_embeds is None: + batch_size = model_input.input_tokens.shape[0] + state = ( + self.runner.graph_runners[model_input.virtual_engine][( + batch_size, False)].attn_state) + else: + batch_size = model_input.inputs_embeds.shape[0] + state = ( + self.runner.graph_runners[model_input.virtual_engine][( + batch_size, True)].attn_state) + + model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper( + ) + model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() + model_input.attn_metadata.begin_forward() + + +@dataclass +class FlashInferMetadata(AttentionMetadata): + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Number of query tokens for each request in the batch. + # Currently, we require that all requests have the same number of query + # tokens during the decoding phase. When speculavie decoding is enabled, + # decode_query_len might be greater than 1. In all other cases, it is 1. + decode_query_len: Optional[int] = 1 + + use_cuda_graph: bool = True + + prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None + decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None + + # Metadata for the prefill stage + seq_start_loc: Optional[torch.Tensor] = None + query_start_loc: Optional[torch.Tensor] = None + block_tables: Optional[torch.Tensor] = None + + # used for GPU in-place advance_step + seq_lens_tensor: Optional[torch.Tensor] = None + block_table_bound: Optional[torch.Tensor] = None + + # An example for paged_kv_indices, paged_kv_indptr: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len: Optional[torch.Tensor] = None + # The number of query/output heads + num_qo_heads: Optional[int] = None + # The number of key/value heads + num_kv_heads: Optional[int] = None + # The dimension of the attention heads + head_dim: Optional[int] = None + # Block size of vllm + page_size: Optional[int] = None + # The data type of the paged kv cache + data_type: torch.dtype = None + # The data type of the query + q_data_type: torch.dtype = None + # FlashInfer 0.2 encourages passing host tensors + device: torch.device = torch.device("cpu") + is_profile_run: bool = False + + # The FlashInfer backend currently supports only models in which all layers + # share the same following hyperparameters: + + # The left (inclusive) window size for the attention window, when + # set to `-1`, the window size will be set to the full length of + # the sequence. Defaults to `-1`. + window_left: int = -1 + # The attention logits soft capping value (used in Gemini, Grok and + # Gemma-2, etc.), if not provided, will be set to `0`. If greater + # than 0, the logits will be capped according to formula: + # $$\texttt{logits\_soft\_cap} \times + # \mathrm{tanh}(x / \texttt{logits\_soft\_cap})$$, + # where $x$ is the input logits. + logits_soft_cap: Optional[float] = None + # The scale used in softmax, if not provided, will be set to + # `1.0 / sqrt(head_dim)`. + sm_scale: Optional[float] = None + + def __post_init__(self): + # Refer to + # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 + supported_head_sizes = FlashInferBackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f" received {self.head_dim}.") + + def begin_forward(self): + if self.num_prefill_tokens > 0: + if self.paged_kv_indices is None: + return + + assert self.prefill_wrapper is not None + assert self.query_start_loc is not None + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None + assert self.block_table_bound is not None + assert self.seq_lens_tensor is not None + self.query_start_loc = self.query_start_loc[:self.num_prefills + 1] + batch_size = self.query_start_loc.shape[0] - 1 + assert batch_size >= 0 + # We will use flash attention for profiling to + # determine the number of blocks. Therefore, + # we don't need to prepare the input for flashinfer for profile run. + if not self.is_profile_run: + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) + self.block_table_bound = self.block_table_bound.to(self.device) + self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + self.prefill_wrapper.plan( + self.query_start_loc, + self.paged_kv_indptr[:self.num_prefills + 1], + self.paged_kv_indices, + self.paged_kv_last_page_len[:self.num_prefills], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.data_type) + if self.num_decode_tokens > 0: + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) + # handle model warmup path + if self.block_table_bound is not None: + self.block_table_bound = self.block_table_bound.to(self.device) + if self.seq_lens_tensor is not None: + self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) + + assert self.decode_wrapper is not None + self.decode_wrapper.plan( + self.paged_kv_indptr[self.num_prefills:], + self.paged_kv_indices, + self.paged_kv_last_page_len[self.num_prefills:], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. + pos_encoding_mode="NONE", + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + sm_scale=self.sm_scale, + # kv-cache data type. + kv_data_type=self.data_type, + # query data type. + q_data_type=self.q_data_type) + + def asdict_zerocopy(self, + skip_fields: Optional[Set[str]] = None + ) -> Dict[str, Any]: + if skip_fields is None: + skip_fields = set() + # We need to skip the prefill/decode_wrapper field since it cannot be + # broadcasted with nccl when TP is enabled. + skip_fields.add('prefill_wrapper') + skip_fields.add('decode_wrapper') + return super().asdict_zerocopy(skip_fields) + + @property + def prefill_metadata(self) -> Optional["FlashInferMetadata"]: + if self.num_prefills == 0: + return None + return self + + @property + def decode_metadata(self) -> Optional["FlashInferMetadata"]: + if self.num_decode_tokens == 0: + return None + return self + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + + if turn_prefills_into_decodes: + # When Multi-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. + assert self.num_decode_tokens + self.num_prefills == num_seqs + # Flashinfer doesn't support speculative decoding + chunked-prefill + # + multi-step scheduling yet. + assert self.decode_query_len == 1 + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens_tensor is not None + + assert num_seqs > 0 + assert num_queries > 0 + assert model_input.attn_metadata is not None + assert sampled_token_ids is not None + + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + model_input.input_tokens[:num_queries] = sampled_token_ids.flatten() + + # Update GPU tensors + ops.advance_step_flashinfer( + num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=model_input.input_tokens, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables, + paged_kv_indices=self.paged_kv_indices, + paged_kv_indptr=self.paged_kv_indptr, + paged_kv_last_page_len=self.paged_kv_last_page_len, + block_table_bound=self.block_table_bound) + + +class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + + self.input_builder = input_builder + self.runner = input_builder.runner + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + # Global hyperparameters shared by all attention layers + self.global_hyperparameters: Optional[PerLayerParameters] = None + + self.vllm_config = self.runner.vllm_config + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # for the precise definition of the following fields. + # An example: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + self.paged_kv_indices: List[int] = [] + # 0 at the beginning of paged_kv_indptr indicates the start of the + # first request’s page indices in the paged_kv_indices list. + self.paged_kv_indptr: List[int] = [0] + # paged_kv_last_page_len is the length of the last page of each request + self.paged_kv_last_page_len: List[int] = [] + self.total_blocks = 0 + self.is_profile_run: bool = False + + if self.global_hyperparameters is None: + # Infer global hyperparameters, since currently we only support + # models in which all layers share the same values for the + # following hyperparameters: + # - `window_left` + # - `logits_soft_cap` + # - `sm_scale` + inferred_params = infer_global_hyperparameters( + get_per_layer_parameters(self.vllm_config)) + self.global_hyperparameters = inferred_params + self.window_left = inferred_params.window_left + self.logits_soft_cap = inferred_params.logits_soft_cap + self.sm_scale = inferred_params.sm_scale + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + computed_block_nums = inter_data.computed_block_nums + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if inter_data.prefix_cache_hit: + block_table = computed_block_nums + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + block_table = block_tables[seq_id][-curr_sliding_window_block:] + self.block_tables.append(block_table) + + is_profile_run = is_block_tables_empty(block_tables) + + # Compute slot mapping. + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + # It is not necessary to add paged_kv_indices, paged_kv_indptr, + # and paged_kv_last_page_len for profile run because we will + # create dummy inputs. + if is_profile_run: + self.is_profile_run = is_profile_run + return + + block_table = block_tables[seq_id] + self._update_paged_kv_tensors(block_table, seq_len) + + def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int): + # Get the number of valid blocks based on sequence length. + # If seq_len = 16, block_size = 16, + # block_table_bound is 1 with 1 valid block. + # If seq_len = 15, block_size = 16, + # block_table_bound is 0 + 1 with 1 valid block. + self.total_blocks += len(block_table) + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + self.paged_kv_indices.extend(block_table[:block_table_bound]) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) + + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.paged_kv_last_page_len.append(last_page_len) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int, + tree_attention_masks_tensor: Optional[torch.Tensor] = None): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + tree_attention_masks_tensor: attention mask used in tree style attention. + """ + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + decode_query_len = max(query_lens[self.num_prefills:], default=1) + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = self.runner.graph_block_tables[:batch_size] + max_blocks = input_block_tables.shape[1] + for i, block_table in enumerate(self.block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + input_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + input_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + block_tables = torch.from_numpy(input_block_tables).to( + device, non_blocking=True) + + last_paged_kv_indptr = self.paged_kv_indptr[-1] + self.paged_kv_indptr.extend([last_paged_kv_indptr] * + cuda_graph_pad_size) + self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + + assert device is not None + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + if len(self.paged_kv_indptr) > 0: + # extend to the maximum number of blocks as returned by the + # scheduler + self.paged_kv_indices.extend( + [0] * (self.total_blocks - len(self.paged_kv_indices))) + paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, + device="cpu", + dtype=torch.int) + paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, + device="cpu", + dtype=torch.int) + paged_kv_last_page_len_tensor = torch.tensor( + self.paged_kv_last_page_len, device="cpu", dtype=torch.int) + block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - + 1, + device="cpu", + dtype=torch.int) + else: + paged_kv_indices_tensor = None + paged_kv_indptr_tensor = None + paged_kv_last_page_len_tensor = None + block_table_bound_tensor = None + + if self.runner.kv_cache_dtype.startswith("fp8"): + kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.runner.kv_cache_dtype) + else: + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) + + return FlashInferMetadata( + decode_query_len=decode_query_len, + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + max_prefill_seq_len=max_prefill_seq_len, + block_tables=block_tables, + paged_kv_indptr=paged_kv_indptr_tensor, + paged_kv_indices=paged_kv_indices_tensor, + paged_kv_last_page_len=paged_kv_last_page_len_tensor, + block_table_bound=block_table_bound_tensor, + seq_lens_tensor=seq_lens_tensor, + num_qo_heads=self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config), + num_kv_heads=self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config), + head_dim=self.runner.model_config.get_head_size(), + page_size=self.block_size, + seq_start_loc=seq_start_loc, + query_start_loc=query_start_loc, + device=device, + data_type=kv_cache_dtype, + q_data_type=self.runner.model_config.dtype, + use_cuda_graph=use_captured_graph, + is_profile_run=self.is_profile_run, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + sm_scale=self.sm_scale, + ) + + +class FlashInferImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if use_irope: + logger.warning_once( + "Using irope in FlashInfer is not supported yet, it will fall" + " back to global attention for long context.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window - 1, + 0) if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + self.logits_soft_cap = logits_soft_cap + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferImpl") + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashInferMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for FlashInferImpl") + + # TODO: directly write to output tensor + num_heads: int = self.num_heads + head_size: int = self.head_size + num_kv_heads: int = self.num_kv_heads + kv_cache_dtype: str = self.kv_cache_dtype + softmax_scale: float = self.scale + window_size = self.sliding_window + alibi_slopes = self.alibi_slopes + logits_soft_cap = self.logits_soft_cap + + num_tokens, hidden_size = query.shape + query = query.view(-1, num_heads, head_size) + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + + if kv_cache.numel() > 0: + # Use the same reshape and cache kernel as flash attention. + ops.reshape_and_cache_flash( + key, + value, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 + # to process the cache when the kv_cache_dtype is fp8 + if kv_cache_dtype.startswith("fp8"): + torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + kv_cache_dtype) + kv_cache = kv_cache.view(torch_dtype) + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa + assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa + query = query.contiguous( + ) # Flashinfer requires query to be contiguous + # Query for decode. KV is not needed because it is already cached. + # QKV for prefill. + decode_query = query[num_prefill_tokens:] + query = query[:num_prefill_tokens] + + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + window_left = window_size[0] if window_size is not None else -1 + + prefill_output: Optional[torch.Tensor] = None + decode_output: Optional[torch.Tensor] = None + stride_order = FlashInferBackend.get_kv_cache_stride_order() + if prefill_meta := attn_metadata.prefill_metadata: + # We will use flash attention for prefill + # when kv_cache is not provided. + # This happens when vllm runs the profiling to + # determine the number of blocks. + if kv_cache.numel() == 0: + prefill_output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + ) + else: + assert prefill_meta is not None + assert prefill_meta.prefill_wrapper is not None + + assert prefill_meta.prefill_wrapper._causal + assert prefill_meta.prefill_wrapper._window_left == window_left + assert prefill_meta.prefill_wrapper._logits_soft_cap == ( + logits_soft_cap or 0.0) + assert prefill_meta.prefill_wrapper._sm_scale == softmax_scale + + prefill_output = prefill_meta.prefill_wrapper.run( + query, + kv_cache.permute(*stride_order), + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + ) + if decode_meta := attn_metadata.decode_metadata: + assert decode_meta is not None + assert decode_meta.decode_wrapper is not None + + assert decode_meta.decode_wrapper._window_left == window_left + assert decode_meta.decode_wrapper._logits_soft_cap == ( + logits_soft_cap or 0.0) + assert decode_meta.decode_wrapper._sm_scale == softmax_scale + + decode_output = decode_meta.decode_wrapper.run( + decode_query, + kv_cache.permute(*stride_order), + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + ) + + if prefill_output is None and decode_output is not None: + # Decode only batch. + output, num_tokens = decode_output, num_decode_tokens + elif decode_output is None and prefill_output is not None: + # Prefill only batch. + output, num_tokens = prefill_output, num_prefill_tokens + else: + # Chunked prefill batch does not work with speculative decoding in + # FlashInfer backend, so the query length for decode should be 1. + assert prefill_output is not None + assert decode_output is not None + assert decode_meta is not None + assert decode_meta.decode_query_len == 1 + decode_output = decode_output.squeeze(1) + output = torch.cat([prefill_output, decode_output], dim=0) + return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py new file mode 100644 index 0000000..19f8cfd --- /dev/null +++ b/vllm/attention/backends/flashmla.py @@ -0,0 +1,249 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, + MLACommonState) +from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, + get_mla_metadata, + is_flashmla_supported) + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + + +class FlashMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASHMLA" + + @staticmethod + def get_impl_cls() -> Type["FlashMLAImpl"]: + return FlashMLAImpl + + @staticmethod + def get_metadata_cls() -> Type["FlashMLAMetadata"]: + return FlashMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]: + return FlashMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["FlashMLAState"]: + return FlashMLAState + + +@dataclass +class FlashMLAMetadata(MLACommonMetadata): + decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None + decode_num_splits: Optional[torch.Tensor] = None + + @property + def decode_metadata(self): + decode_metadata = super().decode_metadata + # TODO: cache assignment? + if decode_metadata is not None: + decode_metadata.decode_tile_scheduler_metadata=\ + self.decode_tile_scheduler_metadata + decode_metadata.decode_num_splits=\ + self.decode_num_splits + return decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + raise NotImplementedError( + "advance_step is not implemented for FlashMLA") + + +class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.num_q_heads = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + m = super().build(seq_lens, query_lens, cuda_graph_pad_size, + batch_size) + + if m.num_decode_tokens > 0: + m.decode_tile_scheduler_metadata, m.decode_num_splits = \ + get_mla_metadata( + m.seq_lens_tensor[m.num_prefills:], + self.num_q_heads, + 1, # MQA for the decode path + ) + + return m + + +class FlashMLAState(MLACommonState[FlashMLAMetadata]): + + def __init__(self, *args, **kwds): + super().__init__(*args, **kwds) + + self.num_q_heads = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) + + @contextmanager + def graph_capture(self, max_batch_size: int): + # Run a dummy `get_mla_metadata` so we can get the right shapes + self._graph_decoder_tile_scheduler_metadata, \ + self._graph_decode_num_splits = get_mla_metadata( + torch.ones( + max_batch_size, dtype=torch.int32, device=self.runner.device), + self.num_q_heads, + 1, # MQA for the decode path + ) + + with super().graph_capture(max_batch_size): + yield + + del self._graph_decoder_tile_scheduler_metadata + del self._graph_decode_num_splits + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + metadata = super().graph_capture_get_metadata_for_batch( + batch_size, is_encoder_decoder_model) + assert metadata.num_decode_tokens > 0 + + decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata( + self._graph_seq_lens[:batch_size], + self.num_q_heads, + 1, # MQA for the decode path + ) + + self._graph_decoder_tile_scheduler_metadata.copy_( + decoder_tile_scheduler_metadata) + self._graph_decode_num_splits[:batch_size + 1].copy_(decode_num_splits) + + metadata.decode_tile_scheduler_metadata=\ + self._graph_decoder_tile_scheduler_metadata + metadata.decode_num_splits=\ + self._graph_decode_num_splits[:batch_size + 1] + + return metadata + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_buffers = super().get_graph_input_buffers( + attn_metadata, is_encoder_decoder_model) + input_buffers["decode_tile_scheduler_metadata"] = \ + attn_metadata.decode_metadata.decode_tile_scheduler_metadata + input_buffers["decode_num_splits"] = \ + attn_metadata.decode_metadata.decode_num_splits + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + super().prepare_graph_input_buffers(input_buffers, attn_metadata, + is_encoder_decoder_model) + + input_buffers["decode_tile_scheduler_metadata"].copy_( + attn_metadata.decode_metadata.decode_tile_scheduler_metadata) + input_buffers["decode_num_splits"].copy_( + attn_metadata.decode_metadata.decode_num_splits) + + +class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str] = None, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + + assert is_flashmla_supported(), \ + "FlashMLA is not supported on this device" + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "FlashMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashMLAImpl") + + if is_quantized_kv_cache(self.kv_cache_dtype): + if self.kv_cache_dtype != "fp8": + raise NotImplementedError( + "FlashMLA with other KV cache not yet supported") + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: FlashMLAMetadata, + k_scale = None, + kv_cache_dtype = "auto", + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + + q = torch.cat([q_nope, q_pe], dim=-1)\ + .unsqueeze(1) # Add seqlen dim of 1 (decode) + + o, _ = flash_mla_with_kvcache( + q=q, + k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + head_dim_v=self.kv_lora_rank, + tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata, + num_splits=decode_meta.decode_num_splits, + softmax_scale=self.scale, + causal=True, + k_scale = k_scale, + kv_cache_dtype = kv_cache_dtype, + ) + + return self._v_up_proj(o) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py new file mode 100644 index 0000000..bf778a1 --- /dev/null +++ b/vllm/attention/backends/hpu_attn.py @@ -0,0 +1,318 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +############################################################################### + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import vllm_hpu_extension.kernels as kernels +import vllm_hpu_extension.ops as ops +from vllm_hpu_extension.flags import enabled_flags +from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention, + HPUPagedAttentionMetadata) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class HPUAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "HPU_ATTN" + + @staticmethod + def get_impl_cls() -> Type["HPUAttentionImpl"]: + return HPUAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return HPUAttentionMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return HPUPagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dsts: torch.Tensor, + ) -> None: + HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dsts) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dsts: torch.Tensor, + ) -> None: + HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts) + + +@dataclass +class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata): + """Metadata for HPUAttentionbackend.""" + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + attn_bias: Optional[torch.Tensor] + seq_lens_tensor: Optional[torch.Tensor] + context_lens_tensor: Optional[torch.Tensor] + + +class HPUAttentionImpl(AttentionImpl, torch.nn.Module): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + max_seq_len: int = 4096, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + super(AttentionImpl, self).__init__() + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if use_irope: + logger.warning_once( + "Using irope in HPU is not supported yet, it will fall back " + "to global attention for long context.") + self.kv_cache_dtype = kv_cache_dtype + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.matmul_qk = Matmul() + self.softmax = Softmax() + self.matmul_av = Matmul() + self.batch2block_matmul = Matmul() + self.block2batch_matmul = Matmul() + self.k_cache = VLLMKVCache() + self.v_cache = VLLMKVCache() + self.fused_scaled_dot_product_attention = kernels.fsdpa() + + self.prefill_impl = 'naive' + if "flex_attention" in enabled_flags(): + self.prefill_impl = 'flex' + if "fsdpa" in enabled_flags(): + assert alibi_slopes is None, \ + 'Prefill with FusedSDPA not supported with alibi slopes!' + self.prefill_impl = 'fsdpa' + + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.sliding_window = sliding_window + self.alibi_slopes = alibi_slopes + if alibi_slopes is not None: + alibi_slopes_tensor = torch.tensor(alibi_slopes, + dtype=torch.bfloat16) + self.alibi_slopes = alibi_slopes_tensor + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if self.prefill_impl == 'fsdpa': + assert alibi_slopes is None, \ + 'Prefill with FusedSDPA not supported with alibi slopes!' + + supported_head_sizes = HPUPagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + + self.attn_type = attn_type + if self.attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "HPUAttentionImpl") + + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "HPUAttention with FP8 KV cache not yet supported") + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: HPUAttentionMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with xFormers and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for HPUAttentionImpl") + + batch_size, seq_len, hidden_size = query.shape + _, seq_len_kv, _ = key.shape + + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + block_indices = attn_metadata.block_indices + block_offsets = attn_metadata.block_offsets + key_cache = None + value_cache = None + if attn_metadata.is_prompt and self.attn_type \ + is not AttentionType.ENCODER_ONLY: + key = key.unflatten(0, (block_indices.size(0), -1)) + value = value.unflatten(0, (block_indices.size(0), -1)) + if kv_cache is not None and isinstance(kv_cache, tuple): + key_cache, value_cache = HPUPagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + key_cache = self.k_cache(key, key_cache, block_indices, + block_offsets) + value_cache = self.v_cache(value, value_cache, block_indices, + block_offsets) + + if attn_metadata.is_prompt: + # Prompt run. + query_shape = (batch_size, seq_len, self.num_heads, self.head_size) + kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, + self.head_size) + + attn_bias = attn_metadata.attn_bias + if attn_bias is not None and self.alibi_slopes is not None: + position_bias = _make_alibi_bias(self.alibi_slopes, + self.num_kv_heads, + attn_bias.dtype, + attn_bias.shape[-1]) + attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1)) + attn_bias.add_(position_bias) + + block_list = attn_metadata.block_list if attn_metadata \ + and attn_metadata.block_list is not None else None + + out = ops.prompt_attention( + impl=self.prefill_impl, + query=query.view(query_shape), + key=key.view(kv_shape), + value=value.view(kv_shape), + is_causal=True, + attn_bias=attn_bias, + valid_seq_lengths=attn_metadata.seq_lens_tensor, + **self.common_attention_args(block_list, key_cache, + value_cache)) + output = out.reshape(batch_size, seq_len, hidden_size) + else: + # Decoding run. + output = HPUPagedAttention.forward_decode( + query=query, + block_mapping=attn_metadata.block_mapping, + block_bias=attn_metadata.attn_bias, + block_groups=attn_metadata.block_groups, + **self.common_attention_args(attn_metadata.block_list, + key_cache, value_cache)) + # Reshape the output tensor. + return output.view(batch_size, seq_len, hidden_size) + + def common_attention_args(self, + block_list=None, + key_cache=None, + value_cache=None): + fsdpa_op = self.fused_scaled_dot_product_attention.apply \ + if self.fused_scaled_dot_product_attention is not None else None + return { + 'scale': self.scale, + 'matmul_qk_op': self.matmul_qk, + 'matmul_av_op': self.matmul_av, + 'batch2block_matmul_op': self.batch2block_matmul, + 'block2batch_matmul_op': self.block2batch_matmul, + 'fsdpa_op': fsdpa_op, + 'keys_fetch_func': self.k_cache.fetch_from_cache, + 'values_fetch_func': self.v_cache.fetch_from_cache, + 'softmax_op': self.softmax, + 'block_list': block_list, + 'key_cache': key_cache, + 'value_cache': value_cache, + } + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + num_kv_heads: int, + dtype: torch.dtype, + seq_len: int, +) -> torch.Tensor: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + # Calculate a matrix where each element represents ith element- jth + # element. + bias = bias[None, :] - bias[:, None] + + padded_len = (seq_len + 7) // 8 * 8 + num_heads = alibi_slopes.shape[0] + bias = torch.empty( + 1, # batch size + num_heads, + seq_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :seq_len].copy_(bias) + bias.mul_(alibi_slopes[:, None, None]) + if num_heads != num_kv_heads: + bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) + return bias diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py new file mode 100644 index 0000000..410ada3 --- /dev/null +++ b/vllm/attention/backends/ipex_attn.py @@ -0,0 +1,403 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" Attention layer with torch scaled_dot_product_attention + and PagedAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm._ipex_ops import ipex_ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) +from vllm.logger import init_logger + +logger = init_logger(__name__) + +_PARTITION_SIZE = 512 + + +class IpexAttnBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "IPEX" + + @staticmethod + def get_impl_cls() -> Type["IpexAttnBackendImpl"]: + return IpexAttnBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["IpexAttnMetadata"]: + return IpexAttnMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + from vllm._ipex_ops import ipex_ops as ops + ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + from vllm._ipex_ops import ipex_ops as ops + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +@dataclass +class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for IpexAttnBackend. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + slot_mapping: torch.Tensor + seq_lens: Optional[List[int]] + seqlen_q: Optional[torch.Tensor] + max_seqlen: Optional[int] + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. + # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[torch.Tensor]] = None + + @property + def prefill_metadata(self) -> Optional["IpexAttnMetadata"]: + # Currently chunked prefill is not supported + if self.num_decode_tokens == 0: + assert self.num_prefills > 0 + return self + + return None + + @property + def decode_metadata(self) -> Optional["IpexAttnMetadata"]: + # Currently chunked prefill is not supported + if self.num_prefills > 0: + assert self.num_decode_tokens == 0 + return None + + return self + + +class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if use_irope: + logger.warning_once( + "Using irope in Ipex is not supported yet, it will fall" + " back to global attention for long context.") + if blocksparse_params is not None: + raise ValueError( + "IPEX backend does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.sliding_window is not None) + if logits_soft_cap is None: + logits_soft_cap = -1 + self.logits_soft_cap = logits_soft_cap + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + if is_quantized_kv_cache(kv_cache_dtype): + raise NotImplementedError( + "IPEX backend does not support FP8 KV cache. " + "Please use xFormers backend instead.") + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "IpexAttnBackendImpl") + + def split_kv_cache( + self, + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 1 + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: IpexAttnMetadata, # type: ignore + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with IPEX varlen_attention and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for IpexAttentionImpl") + + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache.numel() > 0: + key_cache, value_cache = self.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + ipex_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + layer._k_scale_float, + layer._v_scale_float, + ) + + if attn_metadata.is_prompt: + assert attn_metadata.seq_lens is not None + if (kv_cache.numel() == 0 + or attn_metadata.block_tables.numel() == 0): + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=1) + + if attn_metadata.attn_bias is None: + if self.sliding_window is not None: + att_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, self.sliding_window, + query.dtype) # type: ignore + else: + att_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, None, dtype=query.dtype) + attn_metadata.attn_bias = att_masks + + output = torch.empty( + (num_tokens, self.num_heads, self.head_size), + dtype=query.dtype, + device=query.device) + ipex_ops.varlen_attention( + query, + key, + value, + output, + attn_metadata.seqlen_q, + attn_metadata.seqlen_q, + self.alibi_slopes, + attn_metadata.max_seqlen, + attn_metadata.max_seqlen, + pdropout=0.0, + softmax_scale=self.scale, + zero_tensors=False, + is_causal=True, + return_softmax=False, + gen_=None, + window_size_left=-1, + window_size_right=-1, + logits_soft_cap=self.logits_soft_cap, + ) + else: + # prefix-enabled attention + raise RuntimeError( + "IPEX backend doesn't support prefix decoding.") + + else: + # Decoding run. + max_seq_len = attn_metadata.max_decode_seq_len + output = torch.empty_like(query) + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + # TODO(woosuk): Tune this heuristic. + # For context len > 8192, use V2 kernel to avoid shared memory + # shortage. + use_v1 = (max_seq_len <= 8192 and + (max_num_partitions == 1 or num_seqs * num_heads > 512)) + if use_v1: + # Run PagedAttention V1. + ipex_ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + attn_metadata.block_tables, + attn_metadata.seq_lens_tensor, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + layer._k_scale_float, + layer._v_scale_float, + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ipex_ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + attn_metadata.block_tables, + attn_metadata.seq_lens_tensor, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + layer._k_scale_float, + layer._v_scale_float, + ) + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: List[int], +) -> List[torch.Tensor]: + attn_biases = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat((num_heads, 1, 1)) + bias.mul_(alibi_slopes[:, None, None]) + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype, + device=alibi_slopes.device).fill_(-torch.inf).triu_(diagonal=1) + attn_biases.append((bias + inf_mask).to(dtype)) + + return attn_biases + + +def _make_sliding_window_bias( + seq_lens: List[int], + window_size: Optional[int], + dtype: torch.dtype, +) -> List[torch.Tensor]: + attn_biases = [] + for seq_len in seq_lens: + tensor = torch.full( + (1, seq_len, seq_len), + dtype=dtype, + fill_value=1, + ) + shift = 0 + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.log(mask) + attn_biases.append(mask.to(dtype)) + + return attn_biases diff --git a/vllm/attention/backends/mla/__init__.py b/vllm/attention/backends/mla/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py new file mode 100644 index 0000000..a4507f2 --- /dev/null +++ b/vllm/attention/backends/mla/common.py @@ -0,0 +1,1405 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +# MLA Common Components + +This file implements common components for MLA implementations. + +First we define: + +Sq as Q sequence length +Skv as KV sequence length + +MLA has two possible ways of computing, a data-movement friendly approach and a +compute friendly approach, we generally want to use the compute friendly +approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1) +and the data-movement friendly approach for "decode" (i.e. the ratio +Sq / Skv is "large"). + +NOTE what we deem small and large is currently determined by if its labelled +prefill or decode by the scheduler, but this is something we should probably +tune. + +Main reference: DeepseekV2 paper, and FlashInfer Implementation +(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). + +Deepseek's MLA attention works the following way: +* Use a single latent vector to represent the per-token entry of the KV cache. +* For decode (i.e. the memory friendly approach) the attention "simulates" a +multi-head attention, while the compute is similar to multi-query attention. + +Below is example of both paths assuming batchsize = 1 + +## More Extent Definitions: + +C Context length, `Skv - Sq` +H hidden size +N number of attention heads +Lq latent dimension for Q 1536 in DSV3 +Lkv latent dimension for K/V 512 in DSV3 +P nope dimension, no rope. 128 in DSV3 +R rope dimension, goes through rope. 64 in DSV3 +V V head dim. 128 in DSV3 + +## Vector/Matrix Definitions + +h_t hidden states (input to attention) shape [Sq, H] +q_c latent/compressed Q shape [Sq, Lq] +q_nope uncompressed Q (no-rope) shape [Sq, N, P] +q_pe uncompressed Q (rope) shape [Sq, N, R] +kv_c latent/compressed KV shape [Skv, Lkv] +k_pe decoupled k position embeddings shape [Skv, R] +new_kv_c new kv_c from current iter shape [Sq, Lkv] +new_k_pe new k_pe from current iter shape [Sq, R] +cache_kv_c cached k_c from previous iters shape [C, Lkv] +cache_k_pe cached k_pe from previous iters shape [C, R] +W_DQ project h_t to q_c shape [H, Lq] +W_UQ project q_c to q_nope shape [Lq, N * P] +W_QR project q_c to q_pe shape [Lq, N * R] +W_DKV project h_t to kv_c shape [H, Lkv] +W_UK project kv_c to k_nope shape [Lkv, N, P] +W_KR project h_t to k_pe shape [H, R] +W_UV project kv_c to v shape [Lkv, N, V] +W_O project v to h_t shape [N * V, H] + + +## Compute Friendly Approach (i.e. "_forward_prefill"): + +q_c = h_t @ W_DQ +q_nope = (q_c @ W_UQ).view(Sq, N, P) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) +k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) +k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P) +v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V) + +// MHA with QK headdim = P + R +// V headdim = V +// spda_o shape [Sq, N, V] +spda_o = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), + v +) +return spda_o @ W_O + +NOTE: in the actual code, + `kv_b_proj` is [W_UK; W_UV] concatenated per head + `q_b_proj` is [W_UQ; W_QR] concatenated per head + `out_proj` is W_O + + +## Data-Movement Friendly Approach (i.e. "_forward_decode"): + +Runtime +q_c = h_t @ W_DQ +q_nope = (q_c @ W_UQ).view(-1, N, P) +ql_nope = einsum("snh,lnh->snl", q, W_UK) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) +k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) + +// MQA with QK headdim = Lkv + R +// V headdim = Lkv +// spda_o shape [Sq, N, Lkv] +// NOTE: this is less compute-friendly since Lkv > P +// but is more data-movement friendly since its MQA vs MHA +spda_o = scaled_dot_product_attention( + torch.cat([ql_nope, q_pe], dim=-1), + torch.cat([kv_c, k_pe], dim=-1), + kv_c +) + +o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV) +return o.view(-1, N * V) @ self.num_heads @ W_O + + +## Chunked Prefill + +For chunked prefill we want to use the compute friendly algorithm. We are +assuming sufficiently large Sq / Skv ratio, in the future may want to switch to +the data-movement friendly approach if the chunk (i.e. `Sq`) is small. + +However, the compute-friendly approach can potentially run out of memory if Skv +is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)` + +To mitigate this, we chunk the computation of attention with respect to the +current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a +fixed workspace size. + +The chunked prefill approach is as follows: + +MCC Max chunk of context to process per iter, computed dynamically, + used to bound the memory usage + +q_c = h_t @ W_DQ +q_nope = (q_c @ W_UQ).view(Sq, N, P) +q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) +new_kv_c = h_t @ W_DKV +new_k_pe = RoPE(h_t @ W_KR) +new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P) +new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V) + +// MHA between queries and new KV +// with QK headdim = P + R +// V headdim = V +// curr_o shape [Sq, N, V] +// curr_lse shape [N, Sq], this is just order FA returns +curr_o, curr_lse = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), + new_v, + casual=True, + return_softmax_lse=True +) + +// Compute attention with the already existing context +for chunk_idx in range(cdiv(C, MCC)): + chunk_start = chunk_idx * MCC + chunk_end = min(chunk_start + MCC, C) + Sc = chunk_end - chunk_start + cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end] + cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end] + cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P) + cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V) + + chunk_o, chunk_lse = scaled_dot_product_attention( + torch.cat([q_nope, q_pe], dim=-1), + torch.cat([cache_k_nope_chunk, + cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)], + dim=-1), + cache_v_chunk, + casual=False, + return_softmax_lse=True + ) + + curr_o, curr_lse = merge_attn_states( + suffix_output=curr_o, + suffix_lse=curr_lse, + prefix_output=chunk_o, + prefix_lse=chunk_lse, + ) + +return curr_o @ W_O +""" + +import functools +from abc import abstractmethod +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass +from itertools import accumulate +from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, + Type, TypeVar) + +import torch +import os +from vllm import _custom_ops as ops +from vllm import envs +from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionState, MLAAttentionImpl) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.attention.ops.merge_attn_states import merge_attn_states +from vllm.attention.utils.fa_utils import get_flash_attn_version +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, + UnquantizedLinearMethod) +from vllm.multimodal import MultiModalPlaceholderMap +from vllm.platforms import current_platform +from vllm.triton_utils import HAS_TRITON +from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down + +if HAS_TRITON: + from vllm.attention.ops.triton_flash_attention import triton_attention +else: + triton_attention = None + +try: + from vllm.vllm_flash_attn import flash_attn_varlen_func + is_vllm_fa = True +except ImportError: + is_vllm_fa = False + try: + # For rocm use upstream flash attention + from flash_attn import flash_attn_varlen_func + except ImportError: + flash_attn_varlen_func = None + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + +is_hip = current_platform.is_rocm() + + +class MLACommonBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "TRITON_MLA" + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return MLACommonMetadata + + @staticmethod + def get_builder_cls() -> Type["MLACommonMetadataBuilder"]: + return MLACommonMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["MLACommonState"]: + return MLACommonState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + ops.copy_blocks_mla(kv_caches, src_to_dists) + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [576] + + +T = TypeVar("T", bound="MLACommonMetadata") + + +class MLACommonState(AttentionState, Generic[T]): + + def __init__(self, runner): + self.runner = runner + self._is_graph_capturing = False + + scheduler_config = runner.scheduler_config + self.model_config = runner.model_config + cache_config = runner.cache_config + + self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + self.enable_prefix_caching = cache_config.enable_prefix_caching + + if self.chunked_prefill_enabled or self.enable_prefix_caching: + self.context_chunk_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max( + 8 * self.model_config.max_model_len, 4 * + scheduler_config.max_num_seqs * cache_config.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.context_chunk_workspace_size >= \ + scheduler_config.max_num_seqs * cache_config.block_size + + @contextmanager + def graph_capture(self, max_batch_size: int): + self._is_graph_capturing = True + + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + + self._positions = torch.zeros((max_batch_size, ), + dtype=torch.long, + device=self.runner.device) + + yield + + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + del self._positions + + def graph_clone(self, batch_size: int): + assert self._is_graph_capturing + return self.__class__(self.runner) + + def graph_capture_get_metadata_for_batch( + self, + batch_size: int, + is_encoder_decoder_model: bool = False) -> T: + assert self._is_graph_capturing + + attn_metadata = self.runner.attn_backend.make_metadata( + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + use_cuda_graph=True, + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=self._graph_slot_mapping[:batch_size], + seq_lens=None, + seq_lens_tensor=self._graph_seq_lens[:batch_size], + max_query_len=1, + max_decode_query_len=1, + max_prefill_seq_len=0, + max_decode_seq_len=self.runner.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self._graph_block_tables[:batch_size], + head_dim=self.runner.model_config.get_head_size()) + + if is_encoder_decoder_model: + raise NotImplementedError( + "MLACommonState does not support encoder/decoder yet") + + return attn_metadata + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_buffers = { + "slot_mapping": attn_metadata.slot_mapping, + "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": attn_metadata.decode_metadata.block_tables, + } + if is_encoder_decoder_model: + raise NotImplementedError( + "MLACommonState does not support encoder/decoder yet") + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) + input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) + if is_encoder_decoder_model: + raise NotImplementedError( + "TritonMLAState does not support encoder/decoder yet") + + def begin_forward(self, model_input): + if self.chunked_prefill_enabled or self.enable_prefix_caching: + if not hasattr(self, "context_chunk_workspace"): + # not self.runner.device does not return the correct device + # for this process, (init_device sets the correct device but + # only on the Worker). The only way Ive figured out to get the + # correct device is to allocate the workspace on the first call + # to begin_forward and use the device of the input tokens + assert model_input.input_tokens is not None + self.context_chunk_workspace = torch.empty( + (self.context_chunk_workspace_size, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=model_input.input_tokens.device, + ) + + model_input.attn_metadata.context_chunk_workspace = \ + self.context_chunk_workspace + + +@dataclass +class MLACommonMetadata(AttentionMetadata): + """Metadata for MLACommon. + + NOTE: Please read the comment at the top of the file before trying to + understand this class + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Maximum query length in the batch. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + _cached_prefill_metadata: Optional[Any] = None + _cached_decode_metadata: Optional[Any] = None + + num_prefill_tokens: int + + # The dimension of the attention heads + head_dim: Optional[int] = None + + # Used when chunked prefill is enabled to simulate worst case workspace + # allocations, hopefully to avoid going OOM + is_profile_run: bool = False + + # New for MLA (compared to FlashAttention) + # For chunked prefill + context_chunk_cu_seq_lens: Optional[torch.Tensor] = None + context_chunk_starts: Optional[torch.Tensor] = None + context_chunk_seq_tot: Optional[List[int]] = None + context_chunk_max_seq_lens: Optional[List[int]] = None + # Set by MLAAttentionState in `begin_forward` so it doesn't get broadcasted + context_chunk_workspace: Optional[torch.Tensor] = None + + def __post_init__(self): + supported_head_sizes = MLACommonBackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f" received {self.head_dim}.") + + @property + def prefill_metadata(self): + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) + + self._cached_prefill_metadata = self.__class__( + # Required by ModelRunner + use_cuda_graph=False, # Not Attention Related + # Required by Attention Metadata + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + # Required by Attention Metadata (not used) + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + # MLACommonMetadata + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_query_len=0, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + head_dim=self.head_dim, + is_profile_run=self.is_profile_run, + # MLACommonMetadata Chunk prefill specific + context_chunk_cu_seq_lens=self.context_chunk_cu_seq_lens, + context_chunk_starts=self.context_chunk_starts, + context_chunk_seq_tot=self.context_chunk_seq_tot, + context_chunk_max_seq_lens=self.context_chunk_max_seq_lens, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self): + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.seq_lens_tensor is not None + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + + self._cached_decode_metadata = self.__class__( + # Required by ModelRunner + use_cuda_graph=self.use_cuda_graph, # Not Attention Related + # Required by Attention Metadata + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + # Required by Attention Metadata (not used) + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, + # MLACommonMetadata + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=self.max_decode_query_len, + max_query_len=self.max_query_len, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + # Batch may be composed of prefill|decodes, adjust query start + # indices to refer to the start of decodes. E.g. + # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, + context_lens_tensor=None, + block_tables=block_tables, + head_dim=self.head_dim, + is_profile_run=self.is_profile_run) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + + if turn_prefills_into_decodes: + # When Multi-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. + assert self.num_decode_tokens + self.num_prefills == num_seqs + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + self._ops_advance_step(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions) + + def _ops_advance_step(self, num_seqs: int, num_queries: int, + block_size: int, input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor) -> None: + # here we use advance_step_flashinfo to update the paged_kv_* tensors + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + BLOCK_TABLE_EXTENDER: list[list[int]] = [] + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.chunked_prefill_enabled = \ + self.runner.scheduler_config.chunked_prefill_enabled + self.enable_prefix_caching = \ + self.runner.cache_config.enable_prefix_caching + + if self.chunked_prefill_enabled or self.enable_prefix_caching: + attn_state = self.input_builder.runner.attn_state + self.context_chunk_workspace_size = \ + attn_state.context_chunk_workspace_size + self.page_size = self.runner.block_size + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + self.has_prefix_cache_hit = False + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def _get_graph_runner_block_tables( + self, num_seqs: int, + block_tables: List[List[int]]) -> torch.Tensor: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + max_batch_size, max_blocks = self.runner.graph_block_tables.shape + assert max_batch_size >= num_seqs + + graph_block_tables = self.runner.graph_block_tables[:num_seqs] + for i, block_table in enumerate(block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + graph_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + graph_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + return torch.from_numpy(graph_block_tables).to( + device=self.runner.device, non_blocking=True) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + prefix_cache_hit = any([ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ]) + + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + num_seqs = len(seq_lens) + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER * + cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens + + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + + context_chunk_cu_seq_lens = None + context_chunk_starts = None + context_chunk_seq_tot = None + context_chunk_max_seq_lens = None + + if (self.chunked_prefill_enabled or self.enable_prefix_caching) \ + and self.num_prefills > 0 \ + and context_lens_tensor is not None \ + and context_lens_tensor[:self.num_prefills].max() > 0: + + # NOTE: it is recommend you read the `Chunked Prefill` section in + # the comment at the top of the file before trying to understand + # the following code + + num_prefills_with_context = \ + (context_lens_tensor[:self.num_prefills] > 0).sum().item() + + # currently we allocate an equal amount of workspace for each + # prefill in the batch, we could probably use a more advanced + # algorithm here and allocate more workspace to prefills with + # longer context lengths + max_context_chunk = \ + self.context_chunk_workspace_size // num_prefills_with_context + + # align max_context_chunk to page_size by rounding down, + # currently the `gather_cache` kernel cannot handle + # `context_chunk_starts` that are not aligned to page_size + max_context_chunk = round_down(max_context_chunk, self.page_size) + assert max_context_chunk > 0 + num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk) + + # if `max_context_chunk = 256`, `num_chunks = 3`, and + # `num_prefills_with_context = 4`, create a tensor that looks like + # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] + context_chunk_starts = \ + torch.arange(num_chunks, device=device, dtype=torch.int32)\ + .unsqueeze(1).expand(-1, self.num_prefills)\ + * max_context_chunk + chunk_ends = torch.min(context_lens_tensor[:self.num_prefills]\ + .unsqueeze(0), context_chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0) + _context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to( + torch.int32) + zero = torch.zeros(num_chunks, dtype=torch.int32, device=device)\ + .unsqueeze(-1) + context_chunk_cu_seq_lens = \ + torch.cat([zero, _context_chunk_cu_seq_lens], dim=1) + context_chunk_max_seq_lens = \ + chunk_seq_lens.max(dim=1).values.tolist() + context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist() + assert max(context_chunk_seq_tot) <= \ + self.context_chunk_workspace_size + + return self.runner.attn_backend.make_metadata( + # Required by ModelRunner + use_cuda_graph=use_captured_graph, # Not Attention Related + # Required by Attention Metadata + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + # Required by Attention Metadata (not used) + multi_modal_placeholder_index_maps=None, # Not Attention Related + enable_kv_scales_calculation=False, + # MLACommonMetadata + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + head_dim=self.runner.model_config.get_head_size(), + is_profile_run=self.runner.in_profile_run, + # MLACommonMetadata Chunk prefill specific + context_chunk_cu_seq_lens=context_chunk_cu_seq_lens, + context_chunk_starts=context_chunk_starts, + context_chunk_seq_tot=context_chunk_seq_tot, + context_chunk_max_seq_lens=context_chunk_max_seq_lens, + ) + + +class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + q_lora_rank: Optional[int], + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + qk_head_dim: int, + v_head_dim: int, + kv_b_proj: ColumnParallelLinear, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing not supported in V0.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_head_dim + self.v_head_dim = v_head_dim + self.kv_b_proj = kv_b_proj + + self.triton_fa_func = triton_attention + # Handle the differences between the flash_attn_varlen from flash_attn + # and the one from vllm_flash_attn. The former is used on RoCM and the + # latter has an additional parameter to control FA2 vs FA3 + self.flash_attn_varlen_func = flash_attn_varlen_func + self.vllm_flash_attn_version = get_flash_attn_version() + if self.vllm_flash_attn_version is not None: + self.flash_attn_varlen_func = \ + functools.partial(flash_attn_varlen_func, + fa_version=self.vllm_flash_attn_version) + + self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim for attention backends that do + # not support different headdims + # We don't need to pad V if we are on a hopper system with FA3 + if not current_platform.is_rocm(): + self._pad_v = self.vllm_flash_attn_version is None or not ( + self.vllm_flash_attn_version == 3 + and current_platform.get_device_capability()[0] == 9) + else: + self._pad_v = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120 + + def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale, + return_softmax_lse, **kwargs): + maybe_padded_v = v + if self._pad_v: + # maybe_padded_v = torch.nn.functional.pad( + # v, [0, q.shape[-1] - v.shape[-1]], value=0) + maybe_padded_v = torch.nn.functional.pad( + v, [0, q.shape[-1] - v.shape[-1]- 32], value=0) + maybe_padded_v = maybe_padded_v[..., :-32].reshape(v.shape[0], v.shape[1],v.shape[2]) + + if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \ + and not return_softmax_lse: + attn_out = self.triton_fa_func( + q, + k, + maybe_padded_v, + None, # output + kwargs["cu_seqlens_q"], + kwargs["cu_seqlens_k"], + kwargs["max_seqlen_q"], + kwargs["max_seqlen_k"], + kwargs["causal"], + softmax_scale, + None, # bias + ) + elif is_vllm_fa: + attn_out = self.flash_attn_varlen_func( + q=q, + k=k, + v=maybe_padded_v, + return_softmax_lse=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs, + ) + else: + # Use return_attn_probs instead of return_softmax_lse for RoCM + attn_out = self.flash_attn_varlen_func( + q=q, + k=k, + v = maybe_padded_v, + return_attn_probs=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs, + ) + + # Unpack the output if there is multiple results, + # triton always returns (output, softmax_lse), + # vllm_flash_attn returns (output, softmax_lse) when + # `return_softmax_lse = True` + # flash_attn (RoCM) returns (output, softmax_lse, ...) when + # `return_attn_probs = True` + rest = None + if isinstance(attn_out, tuple): + attn_out, *rest = attn_out + + # Remain consistent with old `flash_attn_varlen_func` where there + # is only one output tensor if `return_softmax_lse` is False. + if return_softmax_lse: + assert rest is not None + return attn_out, rest[0] + return attn_out + + def _v_up_proj(self, x): + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + x = torch.bmm(x, self.W_UV) + # Convert from (N, B, V) to (B, N * V) + return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute:" + f" {WEIGHT_NAMES}.") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight if not envs.VLLM_USE_NN else layer.weight.T + + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + if self.use_llama_nn and isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod): + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj) + else: + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) + + def _compute_prefill_context( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ): + prefill_metadata = attn_metadata.prefill_metadata + assert prefill_metadata is not None + assert prefill_metadata.context_chunk_seq_tot is not None + assert prefill_metadata.context_chunk_cu_seq_lens is not None + assert prefill_metadata.context_chunk_starts is not None + assert prefill_metadata.context_chunk_max_seq_lens is not None + assert prefill_metadata.context_lens_tensor is not None + + output = None + iters = len(prefill_metadata.context_chunk_seq_tot) + + # Fetch from attn_metadata directly, since it late bound by + # MLAAttentionState, grabbing it directly `attn_metadata` can avoid + # any weirdness around prefill_metadata caching + assert attn_metadata.context_chunk_workspace is not None + workspace = attn_metadata.context_chunk_workspace + + for i in range(iters): + toks = prefill_metadata.context_chunk_seq_tot[i] + + ops.gather_cache( + src_cache=kv_c_and_k_pe_cache, + dst=workspace, + block_table=prefill_metadata.block_tables, + cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i], + batch_size=prefill_metadata.num_prefills, + seq_starts=prefill_metadata.context_chunk_starts[i], + ) + + kv_c_normed = workspace[:toks]\ + [..., :self.kv_lora_rank] + k_pe = workspace[:toks]\ + [..., self.kv_lora_rank:].unsqueeze(1) + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), + dim=-1) + + attn_output, attn_softmax_lse = \ + self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + ) + + if output is None: + output = attn_output + output_lse = attn_softmax_lse + else: + output_tmp = torch.empty_like(output) + output_lse_tmp = torch.empty_like(output_lse) + merge_attn_states( + output=output_tmp, + output_lse=output_lse_tmp, + prefix_output=output, + prefix_lse=output_lse, + suffix_output=attn_output, + suffix_lse=attn_softmax_lse, + ) + output = output_tmp + output_lse = output_lse_tmp + + return output, output_lse + + def _forward_prefill( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ) -> torch.Tensor: + + prefill_metadata = attn_metadata.prefill_metadata + assert prefill_metadata is not None + + if envs.VLLM_HAS_CONTEXT_DEFAULT: + has_context = prefill_metadata.context_lens_tensor is not None \ + and prefill_metadata.context_lens_tensor.max() > 0 + else: + has_context = False + + kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + + output = self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.query_start_loc, + max_seqlen_q=prefill_metadata.max_prefill_seq_len, + max_seqlen_k=prefill_metadata.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=has_context, + ) + + if has_context: + # ROCm flash_attn_varlen_func will return 3 objects instead of 2 + suffix_output, suffix_lse = output + context_output, context_lse = self._compute_prefill_context( \ + q, kv_c_and_k_pe_cache, attn_metadata) + + output = torch.empty_like(suffix_output) + merge_attn_states( + output=output, + prefix_output=context_output, + prefix_lse=context_lse, + suffix_output=suffix_output, + suffix_lse=suffix_lse, + ) + + # unpad if necessary + if self._pad_v: + output = output[..., :v.shape[-1]] + + return output.flatten(start_dim=-2) + + @abstractmethod + def _forward_decode( + self, + ql_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: T, + ) -> torch.Tensor: + raise NotImplementedError + + def forward( + self, + layer: AttentionLayer, + q: torch.Tensor, # query in unified attn + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + attn_metadata: T, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if output is not None: + raise NotImplementedError( + "output is not yet supported for MLAImplBase") + + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for MLAImplBase") + + if attn_metadata.is_profile_run and \ + attn_metadata.context_chunk_workspace is not None: + # During the profile run try to simulate to worse case output size + # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` + # since this can be large + _ = torch.empty( + (attn_metadata.context_chunk_workspace.shape[0], + self.num_heads, self.qk_nope_head_dim + self.v_head_dim), + device=k_c_normed.device, + dtype=k_c_normed.dtype, + ) + + has_decode = attn_metadata.decode_metadata is not None + has_prefill = attn_metadata.prefill_metadata is not None + + num_prefill_tokens: int = attn_metadata.num_prefill_tokens + q = q.view(-1, self.num_heads, self.qk_head_dim) + + decode_q = q[num_prefill_tokens:] + + prefill_q = q[:num_prefill_tokens] + prefill_k_pe = k_pe[:num_prefill_tokens] + prefill_k_c_normed = k_c_normed[:num_prefill_tokens] + + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + ops.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + + output = torch.empty(attn_metadata.num_prefill_tokens + + attn_metadata.num_decode_tokens, + self.v_head_dim * self.num_heads, + device=q.device, + dtype=q.dtype) + if has_prefill: + output[:num_prefill_tokens] = self._forward_prefill( + prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, + attn_metadata) + + if has_decode: + decode_q_nope, decode_q_pe = decode_q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + # Convert from (B, N, P) to (N, B, P) + decode_q_nope = decode_q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + decode_ql_nope = decode_ql_nope.transpose(0, 1) + + output[num_prefill_tokens:] = self._forward_decode( + decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._k_scale, self.kv_cache_dtype) + + return output \ No newline at end of file diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py new file mode 100644 index 0000000..c900666 --- /dev/null +++ b/vllm/attention/backends/pallas.py @@ -0,0 +1,356 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import torch_xla.experimental.custom_kernel # Required to register custom ops. + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class PallasAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "PALLAS" + + @staticmethod + def get_impl_cls() -> Type["PallasAttentionBackendImpl"]: + return PallasAttentionBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["PallasMetadata"]: + return PallasMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_kv_heads, num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + raise RuntimeError("swap_blocks is not used for the TPU backend.") + + @torch.compile(backend="openxla") + @staticmethod + def copy_blocks( + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + src_to_dists: Tuple[torch.Tensor, torch.Tensor], + ) -> None: + src_indices, dst_indices = src_to_dists + for k_cache, v_cache in kv_caches: + torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True) + k_cache[:, dst_indices] = k_cache[:, src_indices] + torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True) + v_cache[:, dst_indices] = v_cache[:, src_indices] + + +@dataclass +class PallasMetadata(AttentionMetadata): + + # Currently, input sequences can only contain all prefills + # or all decoding. + block_tables: Optional[torch.Tensor] = None + context_lens: Optional[torch.Tensor] = None + effective_query_lens: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self) -> Optional["PallasMetadata"]: + if self.num_prefills == 0: + return None + + assert self.num_decode_tokens == 0 + return self + + @property + def decode_metadata(self) -> Optional["PallasMetadata"]: + if self.num_decode_tokens == 0: + return None + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.block_tables is not None + assert self.context_lens is not None + return self + + +class PallasAttentionBackendImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if use_irope: + logger.warning_once( + "Using irope in Pallas is not supported yet, it will fall back " + "to global attention for long context.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.logits_soft_cap = logits_soft_cap + if head_size % 128 != 0: + raise NotImplementedError( + f"Head size must be a multiple of 128, found {head_size}.") + if alibi_slopes is not None: + raise NotImplementedError("Alibi slopes is not supported.") + if sliding_window is not None: + raise NotImplementedError("Sliding window is not supported.") + if is_quantized_kv_cache(kv_cache_dtype): + raise NotImplementedError("FP8 KV cache dtype is not supported.") + if blocksparse_params is not None: + raise NotImplementedError("Blocksparse is not supported.") + + if torch_xla.tpu.version() < 4: + raise NotImplementedError("TPU version must be 4 or higher.") + + self.megacore_mode = None + tpu_env = torch_xla.tpu.get_tpu_env() + tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None) + or tpu_env.get("TYPE", None) + or tpu_env.get("TPU_ACCELERATOR_TYPE", None)) + assert tpu_type is not None + tpu_type = tpu_type.lower() + + if (("lite" not in tpu_type) and ("v6" not in tpu_type)): + if self.num_kv_heads % 2 == 0: + self.megacore_mode = "kv_head" + else: + # NOTE(woosuk): If the batch size is not a multiple of 2, the + # megacore mode will be None. + self.megacore_mode = "batch" + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl") + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + attn_metadata: PallasMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with Pallas attention. + + Args: + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] + kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size] + kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size] + NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor + with shape [0] for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [batch_size, seq_len, num_heads * head_size] + """ + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for PallasAttentionImpl") + + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + batch_size, seq_len, hidden_size = query.shape + query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) + value = value.view(batch_size, seq_len, self.num_kv_heads, + self.head_size) + + if kv_cache[0].numel() > 0: + slot_mapping = attn_metadata.slot_mapping + key_cache, value_cache = kv_cache + write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) + + query = query * self.scale + if attn_metadata.num_prefills > 0: + if attn_metadata.block_tables is None: + # Prefill without paged KV cache. + assert seq_len % 16 == 0, ( + "Pallas FlashAttention kernel requires seq_len to be a " + f"multiple of 16 but got {seq_len}") + + # Handle GQA/MQA. + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, + dim=-2) + key = key.view(batch_size, seq_len, self.num_heads, + self.head_size) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=-2) + value = value.view(batch_size, seq_len, self.num_heads, + self.head_size) + # FlashAttention kernel requires the input shape to be + # [batch_size, num_heads, seq_len, d_model] + # while the input is [batch_size, seq_len, num_heads, d_model]. + # Permute the input to match the required format. + output = torch.ops.xla.flash_attention( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + True, + ) + output = output.permute(0, 2, 1, 3) + else: + # Prefill with paged KV cache. + # TODO(woosuk): Tune the below knobs. + num_kv_pages_per_compute_block = 16 + num_queries_per_compute_block = 16 + assert seq_len % num_queries_per_compute_block == 0 + output = torch.ops.xla.multi_queries_paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + attn_metadata.effective_query_lens, + num_kv_pages_per_compute_block, + num_queries_per_compute_block, + use_kernel=True, + attn_logits_soft_cap=self.logits_soft_cap, + ) + else: + # Decoding run. + assert kv_cache[0].numel() > 0 + query = query.squeeze(dim=1) + pages_per_compute_block = 16 # TODO(woosuk): Tune this value. + + assert attn_metadata.block_tables is not None + assert attn_metadata.context_lens is not None + # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire + # block table in SMEM. Therefore, if the block table is too large, + # the kernel compilation will fail. To avoid this, we split the + # batch dimension into smaller chunks and run the kernel multiple + # times. + MAX_SMEM_USAGE = 512 * 1024 + size_per_seq = 4 * attn_metadata.block_tables.shape[1] + max_num_seq = MAX_SMEM_USAGE // size_per_seq + + if batch_size <= max_num_seq: + output = paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + pages_per_compute_block, + self.megacore_mode, + attn_logits_soft_cap=self.logits_soft_cap, + ) + else: + chunk_size = max_num_seq + # Make sure the chunk size is a multiple of 2. + chunk_size = chunk_size // 2 * 2 + num_chunks = (batch_size + chunk_size - 1) // chunk_size + + output = torch.empty_like(query) + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = chunk_start + chunk_size + # NOTE(woosuk): We skip this line because it causes Dynamo + # compilation error. Instead, we rely on the slice operation + # to handle the out-of-bound case. + # chunk_end = min(chunk_end, batch_size) + chunk_output = paged_attention( + query[chunk_start:chunk_end], + key_cache, + value_cache, + attn_metadata.context_lens[chunk_start:chunk_end], + attn_metadata.block_tables[chunk_start:chunk_end], + pages_per_compute_block, + self.megacore_mode, + attn_logits_soft_cap=self.logits_soft_cap, + ) + output[chunk_start:chunk_end] = chunk_output + + # Reshape the output tensor. + return output.reshape(batch_size, seq_len, hidden_size) + + +def write_to_kv_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, +) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) + torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) + + key = key.flatten(0, 2) + value = value.flatten(0, 2) + key_cache = key_cache.flatten(0, 2) + value_cache = value_cache.flatten(0, 2) + key_cache.index_copy_(0, slot_mapping, key) + value_cache.index_copy_(0, slot_mapping, value) + + +def paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + pages_per_compute_block: int, + megacore_mode: Optional[str], + *, + attn_logits_soft_cap: Optional[float], +) -> torch.Tensor: + batch_size = query.shape[0] + if megacore_mode == "batch" and batch_size % 2 != 0: + megacore_mode = None + else: + megacore_mode = megacore_mode + + return torch.ops.xla.paged_attention( + query, + key_cache, + value_cache, + context_lens, + block_tables, + pages_per_compute_block, + megacore_mode=megacore_mode, + attn_logits_soft_cap=attn_logits_soft_cap, + ) diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py new file mode 100644 index 0000000..820ddca --- /dev/null +++ b/vllm/attention/backends/placeholder_attn.py @@ -0,0 +1,400 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections import defaultdict +from dataclasses import dataclass +from itertools import accumulate +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.multimodal import MultiModalPlaceholderMap + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) +from vllm.utils import async_tensor_h2d + +# Placeholder attention backend for models like Mamba and pooling models that +# lack attention. + + +class PlaceholderAttentionBackend(AttentionBackend): + """Placeholder backend for when no attention is needed.""" + + @staticmethod + def get_name() -> str: + return "NO_ATTENTION" + + @staticmethod + def get_impl_cls() -> Type["PlaceholderAttentionImpl"]: + return PlaceholderAttentionImpl + + @staticmethod + def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]: + return PlaceholderAttentionMetadataBuilder + + @staticmethod + def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]: + return PlaceholderAttentionMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (1, 1, 1, 1, 1) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + return + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + return + + +@dataclass +class PlaceholderAttentionMetadata(AttentionMetadata): + """Attention metadata for prefill and decode batched together.""" + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # Maximum query length in the batch. + max_query_len: Optional[int] + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + # Placeholder. + block_tables: Optional[torch.Tensor] = None + + _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None + _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + + # Placeholders + slot_mapping = torch.empty(0) + block_tables = torch.empty(0) + + self._cached_prefill_metadata = PlaceholderAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=0, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.seq_lens_tensor is not None + + # Placeholders + slot_mapping = torch.empty(0) + block_tables = torch.empty(0) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + + self._cached_decode_metadata = PlaceholderAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=self.max_decode_query_len, + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, + context_lens_tensor=None, + block_tables=block_tables, + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + assert not turn_prefills_into_decodes, \ + ("Multi-Step + Chunked-Prefill is not supported for attention-free" + "models. turn_prefills_into_decodes is a " + "Multi-Step + Chunked-Prefill specific parameter.") + + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + # Update sequences, masking off entries greater than num_queries + device = self.seq_lens_tensor.device + mask = torch.arange(self.seq_lens_tensor.size(0), + device=device) < num_queries + self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype) + if sampled_token_ids is not None: + model_input.input_tokens.masked_scatter_( + mask, sampled_token_ids[:num_queries]) + + +class PlaceholderAttentionMetadataBuilder( + AttentionMetadataBuilder[PlaceholderAttentionMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + + self.input_builder = input_builder + self.runner = input_builder.runner + + def prepare(self): + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + """ + is_prompt = inter_data.is_prompt + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + + # Some input builders such as ModelInputForCPUBuilder do not have the + # "inter_data_list" attribute. + # Let's check inter_data_list exists before we reference it. + if hasattr(self.input_builder, "inter_data_list"): + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + if use_captured_graph: + num_decode_tokens = batch_size - self.num_prefill_tokens + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + + # Placeholders + slot_mapping_tensor = torch.empty(0) + block_tables = torch.empty(0) + + return PlaceholderAttentionMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + +class PlaceholderAttentionImpl(AttentionImpl): + + def __init__(self, *args, **kwargs) -> None: + return + + def forward(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py new file mode 100644 index 0000000..1edf343 --- /dev/null +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -0,0 +1,435 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, Type, Union + +import torch + +import vllm._custom_ops as ops +import vllm.envs as envs +from vllm.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, + MLACommonState) +from vllm.attention.backends.utils import (compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd, + get_aiter_mla_metadata) + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + + +def is_aiter_mla_enabled() -> bool: + return envs.VLLM_ROCM_USE_AITER \ + and envs.VLLM_ROCM_USE_AITER_MLA + + +class AiterMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "ROCM_AITER_MLA" + + @staticmethod + def get_impl_cls() -> Type["AiterMLAImpl"]: + return AiterMLAImpl + + @staticmethod + def get_metadata_cls() -> Type["AiterMLAMetadata"]: + return AiterMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]: + return AiterMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["AiterMLAState"]: + return AiterMLAState + + +@dataclass +class AiterMLAMetadata(MLACommonMetadata): + # The following 5 tensors are for current version of AITER MLA + block_table_bound: Optional[torch.Tensor] = None + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_lens: Optional[torch.Tensor] = None + + # This is just to make new AITER MLA API work + # -- MTP support is not added yet. + qo_indptr: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self): + prefill_metadata = super().prefill_metadata + self._cached_prefill_metadata = prefill_metadata + + if prefill_metadata is not None: + prefill_metadata.paged_kv_indptr = self.paged_kv_indptr + prefill_metadata.paged_kv_indices = self.paged_kv_indices + prefill_metadata\ + .paged_kv_last_page_lens = self.paged_kv_last_page_lens + prefill_metadata.block_table_bound = self.block_table_bound + prefill_metadata.qo_indptr = self.qo_indptr + + # update the cache + self._cached_prefill_metadata = self.__class__( + **prefill_metadata.__dict__) + + return self._cached_prefill_metadata + + @property + def decode_metadata(self): + decode_metadata = super().decode_metadata + + self._cached_decode_metadata = decode_metadata + + if decode_metadata is not None: + decode_metadata.paged_kv_indptr = self.paged_kv_indptr + decode_metadata.paged_kv_indices = self.paged_kv_indices + decode_metadata\ + .paged_kv_last_page_lens = self.paged_kv_last_page_lens + decode_metadata.block_table_bound = self.block_table_bound + decode_metadata.qo_indptr = self.qo_indptr + + # update the cache + self._cached_decode_metadata = self.__class__( + **decode_metadata.__dict__) + + return self._cached_decode_metadata + + def _ops_advance_step(self, num_seqs: int, num_queries: int, + block_size: int, input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor) -> None: + + ops.advance_step_flashinfer( + num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables, + paged_kv_indices=self.paged_kv_indices, + paged_kv_indptr=self.paged_kv_indptr, + paged_kv_last_page_lens=self.paged_kv_last_page_lens, + block_table_bound=self.block_table_bound) + + +class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): + BLOCK_TABLE_EXTENDER: list[list[int]] = [[]] + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + super().__init__(input_builder) + assert self.block_size == 1, "AITER MLA requires only block size 1." + + def prepare(self): + super().prepare() + self.paged_kv_indices: list[int] = [] + self.paged_kv_indptr: list[int] = [0] + self.paged_kv_last_page_lens: list[int] = [] + self.total_blocks = 0 + self.qo_indptr: list[int] = [0] + + def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, + prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + if is_profile_run: + return + + # Update paged_kv_* tensors only for non-profile run + block_table = block_tables[seq_id] + self._update_paged_kv_tensors(block_table, seq_len) + + def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int): + # Get the number of valid blocks based on sequence length. + # If seq_len = 16, block_size = 16, + # block_table_bound is 1 with 1 valid block. + # If seq_len = 15, block_size = 16, + # block_table_bound is 0 + 1 with 1 valid block. + self.total_blocks += len(block_table) + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + self.paged_kv_indices.extend(block_table[:block_table_bound]) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) + self.qo_indptr.append(self.qo_indptr[-1] + 1) + + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.paged_kv_last_page_lens.append(last_page_len) + + def build(self, seq_lens: list[int], query_lens: list[int], + cuda_graph_pad_size: int, batch_size: int) -> AiterMLAMetadata: + metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size, + batch_size) + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + if use_captured_graph: + last_paged_kv_indptr = self.paged_kv_indptr[-1] + self.paged_kv_indptr.extend([last_paged_kv_indptr] * + cuda_graph_pad_size) + self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size) + last_qo_indptr = self.qo_indptr[-1] + self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size) + + # For current version of AITER MLA + if len(self.paged_kv_indptr) > 0: + # extend to the maximum number of blocks as returned by the + # scheduler + self.paged_kv_indices.extend( + [0] * (self.total_blocks - len(self.paged_kv_indices))) + paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, + device=device, + dtype=torch.int) + paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, + device=device, + dtype=torch.int) + paged_kv_last_page_lens_tensor = torch.tensor( + self.paged_kv_last_page_lens, device=device, dtype=torch.int) + block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - + 1, + device=device, + dtype=torch.int) + + qo_indptr = torch.tensor(self.qo_indptr, + device=device, + dtype=torch.int) + else: + paged_kv_indices_tensor = None + paged_kv_indptr_tensor = None + paged_kv_last_page_lens_tensor = None + block_table_bound_tensor = None + qo_indptr = None + + metadata.paged_kv_indptr = paged_kv_indptr_tensor + metadata.paged_kv_indices = paged_kv_indices_tensor + metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor + metadata.block_table_bound = block_table_bound_tensor + metadata.qo_indptr = qo_indptr + + return metadata + + +class AiterMLAState(MLACommonState[AiterMLAMetadata]): + + @contextmanager + def graph_capture(self, max_batch_size: int): + kv_indices, kv_indptr, last_page_lens, qo_indptr = \ + get_aiter_mla_metadata( + max_batch_size=max_batch_size, + block_size=self.runner.block_size, + max_block_per_batch=\ + self.runner.get_max_block_per_batch(), + device=self.runner.device) + self._paged_kv_indices_tensor = kv_indices + self._paged_kv_indptr_tensor = kv_indptr + self._paged_kv_last_page_lens_tensor = last_page_lens + self._qo_indptr_tensor = qo_indptr + + with super().graph_capture(max_batch_size): + yield + + del self._paged_kv_indices_tensor + del self._paged_kv_indptr_tensor + del self._paged_kv_last_page_lens_tensor + del self._qo_indptr_tensor + + def graph_capture_get_metadata_for_batch( + self, + batch_size: int, + is_encoder_decoder_model: bool = False) -> AiterMLAMetadata: + + metadata = super().graph_capture_get_metadata_for_batch( + batch_size, is_encoder_decoder_model) + + paged_kv_indptr = self._paged_kv_indptr_tensor[:batch_size + 1] + paged_kv_indices = self._paged_kv_indices_tensor + paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[: + batch_size] + qo_indptr = self._qo_indptr_tensor[:batch_size + 1] + + metadata.paged_kv_indptr = paged_kv_indptr + metadata.paged_kv_indices = paged_kv_indices + metadata.paged_kv_last_page_lens = paged_kv_last_page_lens + metadata.qo_indptr = qo_indptr + + return metadata + + def get_graph_input_buffers(self, + attn_metadata: AiterMLAMetadata, + is_encoder_decoder_model: bool = False): + input_buffers = super().get_graph_input_buffers( + attn_metadata, is_encoder_decoder_model) + input_buffers[ + 'paged_kv_indptr'] = attn_metadata.decode_metadata.paged_kv_indptr + input_buffers[ + "paged_kv_indices"] = attn_metadata.\ + decode_metadata.paged_kv_indices + input_buffers[ + "paged_kv_last_page_lens"] = attn_metadata.\ + decode_metadata.paged_kv_last_page_lens + input_buffers['qo_indptr'] = attn_metadata.qo_indptr + + return input_buffers + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata: AiterMLAMetadata, + is_encoder_decoder_model: bool = False): + super().prepare_graph_input_buffers(input_buffers, attn_metadata, + is_encoder_decoder_model) + + num_total_blocks = attn_metadata.decode_metadata.paged_kv_indices.shape[ + 0] + input_buffers["paged_kv_indptr"].copy_( + attn_metadata.decode_metadata.paged_kv_indptr, non_blocking=True) + input_buffers["paged_kv_indices"][:num_total_blocks].copy_( + attn_metadata.decode_metadata.paged_kv_indices, non_blocking=True) + input_buffers["paged_kv_last_page_lens"].copy_( + attn_metadata.decode_metadata.paged_kv_last_page_lens, + non_blocking=True) + input_buffers["qo_indptr"].copy_( + attn_metadata.decode_metadata.qo_indptr, non_blocking=True) + + +class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "Aiter MLA does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + from aiter import flash_attn_varlen_func + self.flash_attn_varlen_func = flash_attn_varlen_func + + def _flash_attn_varlen_diff_headdims( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + softmax_scale: float, return_softmax_lse: bool, + **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]: + output = self.flash_attn_varlen_func( + q, + k, + v, + **kwargs, + ) + + return output + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: AiterMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + B = q_nope.shape[0] + + q = torch.cat([q_nope, q_pe], dim=-1) + o = torch.empty(B, + self.num_heads, + self.kv_lora_rank, + dtype=q.dtype, + device=q.device) + + kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) + + aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, + attn_metadata.qo_indptr, + attn_metadata.max_query_len, + attn_metadata.paged_kv_indptr, + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_lens) + + return self._v_up_proj(o) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py new file mode 100644 index 0000000..e21f44c --- /dev/null +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -0,0 +1,1096 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer ROCm GPUs.""" +import itertools +from dataclasses import dataclass +from functools import cache +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch +import triton +from triton.compiler.compiler import triton_key + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import (CommonAttentionState, + CommonMetadataBuilder) +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) +from vllm.config import get_current_vllm_config +from vllm.logger import init_logger +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.platforms.rocm import use_rocm_custom_paged_attention +from vllm.utils import SUPPORT_TC, gpuname + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) +_PARTITION_SIZE_ROCM = 256 + + +@cache +def is_rocm_aiter_paged_attn_enabled() -> bool: + return envs.VLLM_ROCM_USE_AITER_PAGED_ATTN \ + and envs.VLLM_ROCM_USE_AITER \ + + +@cache +def _get_paged_attn_module() -> PagedAttention: + """ + Initializes the appropriate PagedAttention module from `attention/ops`, + which is used as helper function + by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`. + + The choice of attention module depends on whether + AITER paged attention is enabled: + - If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`. + - Otherwise, it defaults to using the original `PagedAttention`. + """ + if is_rocm_aiter_paged_attn_enabled(): + # Import AITERPagedAttention only when the flag is enabled + from vllm.attention.ops.rocm_aiter_paged_attn import ( + AITERPagedAttention) + return AITERPagedAttention() + return PagedAttention() + + +class ROCmFlashAttentionBackend(AttentionBackend): + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "ROCM_FLASH" + + @staticmethod + def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: + return ROCmFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return ROCmFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]: + return ROCmFlashAttentionMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + paged_attn = _get_paged_attn_module() + return paged_attn.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + paged_attn = _get_paged_attn_module() + paged_attn.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + paged_attn = _get_paged_attn_module() + paged_attn.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for FlashAttentionBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] = None + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None + + tree_attention_masks_tensor: Optional[torch.Tensor] = None + block_tables_list: Optional[List[int]] = None + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.block_tables is not None + + self._cached_prefill_metadata = ROCmFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables, + tree_attention_masks_tensor=self.tree_attention_masks_tensor, + block_tables_list=self.block_tables_list) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = ROCmFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables, + tree_attention_masks_tensor=self.tree_attention_masks_tensor, + block_tables_list=self.block_tables_list) + # Batch may be composed of prefill|decodes, adjust query start indices + # to refer to the start of decodes when the two are split apart. + # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + if self._cached_decode_metadata.query_start_loc is not None: + qs = self._cached_decode_metadata.query_start_loc + self._cached_decode_metadata.query_start_loc = qs - qs[0] + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + + assert not turn_prefills_into_decodes, \ + ("Chunked prefill is not supported with rocm_flash_attn yet." + "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill " + "specific parameter.") + + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +class ROCmFlashAttentionMetadataBuilder( + CommonMetadataBuilder[ROCmFlashAttentionMetadata]): + + _metadata_cls = ROCmFlashAttentionMetadata + + +def _make_alibi_bias(alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: Optional[List[int]], + make_attn_mask: bool = True) -> List[torch.Tensor]: + attn_biases = [] + if seq_lens: + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat( + (num_heads, 1, 1)).to(alibi_slopes.device) + bias.mul_(alibi_slopes[:, None, None]) + if make_attn_mask: + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to( + alibi_slopes.device) + attn_biases.append((bias + inf_mask).to(dtype)) + else: + attn_biases.append(bias.to(dtype)) + + return attn_biases + + +def _get_seq_len_block_table_args( + attn_metadata: ROCmFlashAttentionMetadata, + attn_type: str, +) -> tuple: + ''' + The particular choice of sequence-length + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. + + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths + Encoder attn -> select encoder sequence lengths fields + Encoder-only attn -> select prefill sequence lengths with + bidirectional attention + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention op + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention, encoder-only + + Returns: + + * Appropriate sequence-lengths tensors for query and key + * Appropriate max sequence-length scalar + * Causal masking flag + ''' + + if attn_type == AttentionType.ENCODER: + assert attn_metadata.encoder_seq_lens is not None + assert attn_metadata.encoder_seq_lens_tensor is not None + query_seq_start_loc = torch.tensor( + list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)), + device=attn_metadata.encoder_seq_lens_tensor.device, + dtype=attn_metadata.encoder_seq_lens_tensor.dtype) + causal_mask = False + + # No block tables associated with encoder attention + return (query_seq_start_loc, attn_metadata.max_encoder_seq_len, + query_seq_start_loc, attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_lens, causal_mask) + + elif attn_type == AttentionType.ENCODER_ONLY: + # For encoder-only models, we use the prefill sequence lengths + assert attn_metadata.seq_lens is not None + assert attn_metadata.seq_lens_tensor is not None + query_seq_start_loc = torch.tensor( + list(itertools.accumulate([0] + attn_metadata.seq_lens)), + device=attn_metadata.seq_lens_tensor.device, + dtype=attn_metadata.seq_lens_tensor.dtype) + max_seq_len = attn_metadata.max_prefill_seq_len + # Encoder-only models typically use bidirectional attention + causal_mask = False + + return (query_seq_start_loc, max_seq_len, query_seq_start_loc, + max_seq_len, attn_metadata.seq_lens, causal_mask) + + elif attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + assert attn_metadata.seq_lens is not None + assert attn_metadata.seq_lens_tensor is not None + query_seq_start_loc = torch.tensor( + list(itertools.accumulate([0] + attn_metadata.seq_lens)), + device=attn_metadata.seq_lens_tensor.device, + dtype=attn_metadata.seq_lens_tensor.dtype) + max_seq_len = attn_metadata.max_prefill_seq_len + causal_mask = True + + return (query_seq_start_loc, max_seq_len, query_seq_start_loc, + max_seq_len, attn_metadata.seq_lens, causal_mask) + elif attn_type == AttentionType.ENCODER_DECODER: + assert attn_metadata.seq_lens is not None + assert attn_metadata.encoder_seq_lens_tensor is not None + query_start_loc = torch.tensor( + list(itertools.accumulate([0] + attn_metadata.seq_lens)), + device=attn_metadata.encoder_seq_lens_tensor.device, + dtype=attn_metadata.encoder_seq_lens_tensor.dtype) + + assert attn_metadata.encoder_seq_lens is not None + assert attn_metadata.seq_lens_tensor is not None + key_seq_start_loc = torch.tensor( + list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)), + device=attn_metadata.seq_lens_tensor.device, + dtype=attn_metadata.seq_lens_tensor.dtype) + causal_mask = False + + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (query_start_loc, attn_metadata.max_prefill_seq_len, + key_seq_start_loc, attn_metadata.max_encoder_seq_len, + attn_metadata.seq_lens, causal_mask) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +class ROCmFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->| + |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if use_irope: + logger.warning_once( + "Using irope in ROCm Flash Attention is not supported yet, it " + "will fail back to global attention for long context.") + if blocksparse_params is not None: + raise ValueError( + "ROCmFlashAttention does not support blocksparse attention.") + if use_irope: + logger.warning( + "Using irope in V0 is not supported yet, it will fall back " + "to global attention for long context.") + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + self.logits_soft_cap = 0.0 + else: + self.logits_soft_cap = logits_soft_cap + self.attn_type = attn_type + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.paged_attn_module = _get_paged_attn_module() + supported_head_sizes = self.paged_attn_module.get_supported_head_sizes( + ) + + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + + if SUPPORT_TC: + self.use_naive_attn = False + # NOTE: Allow for switching between Triton and CK. Defaulting to triton. + self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN + if self.use_triton_flash_attn: + if logits_soft_cap is not None: + raise ValueError( + "ROCm Triton FlashAttention does not support attention" + " logits soft capping." + " please try using the ROCm CK " + "FA backend instead by setting the env var " + "`VLLM_USE_TRITON_FLASH_ATTN=0`") + + from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 + triton_attention) + self.triton_attn_func = triton_attention + logger.debug("Using Triton FA in ROCmBackend") + if self.sliding_window != (-1, -1): + logger.warning("ROCm Triton FA does not currently support " + "sliding window attention. If using half " + "precision, please try using the ROCm CK " + "FA backend instead by setting the env var " + "`VLLM_USE_TRITON_FLASH_ATTN=0`") + else: + # if not using triton, navi3x/navi21/navi10 do not use flash-attn + # either + if not current_platform.has_device_capability(90): + self.use_naive_attn = True + else: + if SUPPORT_TC: + try: + from flash_attn import flash_attn_varlen_func, vllm_flash_attn_varlen_func # , vllm_flash_attn_with_kvcache # noqa: F401 + self.fa_attn_func = flash_attn_varlen_func + self.fa_prefix_attn_func = vllm_flash_attn_varlen_func + # self.fa_decode_attn_func = vllm_flash_attn_with_kvcache + + logger.debug("Using CUTLASS FA in ROCmBackend") + except ModuleNotFoundError: + self.use_naive_attn = True + else: + self.use_naive_attn = True + envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN = True + + if self.use_naive_attn: + if logits_soft_cap is not None: + raise ValueError( + "ROCm Naive FlashAttention does not support " + "attention logits soft capping.") + + self.sdpa_attn_func = _sdpa_attention + logger.debug("Using naive (SDPA) attention in ROCmBackend") + + self.aiter_kv_scales_initialized = False + self.force_fp8_attention = ( + get_current_vllm_config() is not None + and get_current_vllm_config().model_config.override_attention_dtype + == "fp8") + + def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" + tokens, n_kv_heads, head_dim = x.shape + return (x[:, :, + None, :].expand(tokens, n_kv_heads, n_rep, + head_dim).reshape(tokens, n_kv_heads * n_rep, + head_dim)) + + def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, + group_shape: tuple[int, int]): + if self.use_triton_flash_attn: + return dtype == current_platform.fp8_dtype( + ) and static and group_shape == (-1, -1) # per-tensor + + # Only supported in the Triton backend + return False + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: ROCmFlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention and PagedAttention. + + For decoder-only models: query, key and value must be non-None. + + For encoder/decoder models: + * ROCmFlashAttentionImpl.forward() may be invoked for both self- and + cross-attention layers. + * For self-attention: query, key and value must be non-None. + * For cross-attention: + * Query must be non-None + * During prefill, key and value must be non-None; key and value + get cached for use during decode. + * During decode, key and value may be None, since: + (1) key and value tensors were cached during prefill, and + (2) cross-attention key and value tensors do not grow during + decode + + A note on how the attn_type (attention type enum) argument impacts + attention forward() behavior: + + * DECODER: normal decoder-only behavior; + use decoder self-attention block table + * ENCODER: no KV caching; pass encoder sequence + attributes (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) to kernel, in lieu of decoder + sequence attributes (seq_lens/seq_lens_tensor/max_seq_len) + * ENCODER_DECODER: cross-attention behavior; + use cross-attention block table for caching KVs derived + from encoder hidden states; since KV sequence lengths + will match encoder sequence lengths, pass encoder sequence + attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) + * ENCODER_ONLY: bidirectional attention with no KV caching; + use prefill sequence attributes + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." + + if output_scale is not None and not self.use_triton_flash_attn: + raise NotImplementedError( + "fused output quantization only supported for Triton" + " implementation in ROCMFlashAttentionImpl for now") + + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None + + paged_attn = self.paged_attn_module + + # Reshaping kv tensors is required for AITER paged attention kernel + # because it works on a different tensor shape, + # when the size of one element is one byte (int8/fp8 dtypes). + # This reshaping is only required on the first forward call + # and the kv cache must not be empty. + if (is_rocm_aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1 + and not self.aiter_kv_scales_initialized + and kv_cache.shape != torch.Size([0])): + num_blocks = kv_cache.shape[1] + block_size = kv_cache.shape[2] // (self.num_kv_heads * + self.head_size) + k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size), + dtype=torch.float32, + device=kv_cache.device) + v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size), + dtype=torch.float32, + device=kv_cache.device) + self.aiter_kv_scales_initialized = True + k_scale.fill_(layer._k_scale.item()) + v_scale.fill_(layer._v_scale.item()) + layer._k_scale = k_scale + layer._v_scale = v_scale + + # Only update KV cache for decoder self-attention + # and encoder-decoder cross-attention + if self.attn_type not in [ + AttentionType.ENCODER, AttentionType.ENCODER_ONLY + ] and kv_cache.numel() > 0: + key_cache, value_cache = paged_attn.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + if key is not None and value is not None: + # Reshape the input keys and values and store them in the + # cache. If kv_cache is not provided, the new key and value + # tensors are not cached. This happens during the initial + # memory profiling run. + paged_attn.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping + if self.attn_type != AttentionType.ENCODER_DECODER else + attn_metadata.cross_slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if self.attn_type != AttentionType.ENCODER: + num_prefill_tokens = attn_metadata.num_prefill_tokens + elif self.attn_type == AttentionType.ENCODER_ONLY: + # For encoder-only models, all tokens are processed in one go + num_prefill_tokens = query.shape[0] + else: + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens + + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + + # For encoder-only and encoder models, + # we process all tokens at once + # For decoder and encoder-decoder, + # we may need to limit key/value to prefill tokens + if key is not None and value is not None \ + and self.attn_type not in [AttentionType.ENCODER_DECODER, + AttentionType.ENCODER_ONLY]: + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + # normal attention and DECODER + if self.attn_type == AttentionType.DECODER and ( + kv_cache.numel() == 0 or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + (query_seq_start_loc, query_max_seq_len, key_seq_start_loc, + key_max_seq_len, seq_lens, + causal_mask) = (prefill_meta.seq_start_loc, + prefill_meta.max_prefill_seq_len, + prefill_meta.seq_start_loc, + prefill_meta.max_prefill_seq_len, + attn_metadata.seq_lens, True) + # prefix-enabled attention and ENCODER/ENCODER_DECODER + else: + (query_seq_start_loc, query_max_seq_len, key_seq_start_loc, + key_max_seq_len, seq_lens, + causal_mask) = _get_seq_len_block_table_args( + prefill_meta, self.attn_type) + # Prompt run. + if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: + # triton attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + attn_masks = None + if self.use_triton_flash_attn: + if self.alibi_slopes is not None: + attn_masks = _make_alibi_bias( + self.alibi_slopes, + query.dtype, + seq_lens, + make_attn_mask=causal_mask) # type: ignore + + use_fp8_scales = (layer._q_scale and layer._k_scale + and layer._v_scale and layer._prob_scale + and (self.kv_cache_dtype == "fp8" + or self.force_fp8_attention)) + + full_scales = ( + layer._q_scale.item(), layer._k_scale.item(), + layer._v_scale.item(), + layer._prob_scale.item()) if use_fp8_scales else None + self.triton_attn_func( + query, + key, + value, + output[:num_prefill_tokens], + query_seq_start_loc, + key_seq_start_loc, + query_max_seq_len, + key_max_seq_len, + causal_mask, + self.scale, + attn_masks[0][None] + if attn_masks is not None else None, + full_scales, + output_scale, + ) + elif self.use_naive_attn: + if self.num_kv_heads != self.num_heads: + # Interleave for MQA workaround. + key = self.repeat_kv(key, self.num_queries_per_kv) + value = self.repeat_kv(value, self.num_queries_per_kv) + if self.alibi_slopes is not None: + attn_masks = _make_alibi_bias( + self.alibi_slopes, + query.dtype, + attn_metadata.seq_lens, + make_attn_mask=causal_mask) # type: ignore + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + # sdpa math backend attention + self.sdpa_attn_func( + query, + key, + value, + output[:num_prefill_tokens], + query_seq_start_loc, + num_prefill_tokens, + self.num_heads, + self.head_size, + self.scale, + attn_masks, + ) + else: + # upstream FA does not support an output arg, copy + output[:num_prefill_tokens] = self.fa_attn_func( + q=query, + k=key, + v=value, + cu_seqlens_q=query_seq_start_loc, + cu_seqlens_k=key_seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=key_max_seq_len, + softmax_scale=self.scale, + causal=causal_mask, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ) + + else: + # prefix-enabled attention - + # not applicable for encoder-only models + if envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN: + version_key = triton_key() + if self.attn_type != AttentionType.ENCODER_ONLY: + output[:num_prefill_tokens] = paged_attn.forward_prefix( + query, + key, + value, + self.kv_cache_dtype, + key_cache, + value_cache, + prefill_meta.block_tables, + prefill_meta.query_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.max_query_len, + self.alibi_slopes, + self.sliding_window[0], + layer._k_scale, + layer._v_scale, + ) + else: + assert self.attn_type != AttentionType.ENCODER_ONLY, ( + "Only decoder-only models support prefix caching") + assert prefill_meta.seq_lens is not None + assert prefill_meta.query_start_loc is not None + max_seq_len = max(prefill_meta.seq_lens) + descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, + key.shape[1]) + ''' + k_cache + triton: [GPU blocks, num_kv_heads, head_size // x, block_size, x] ---> + cutlass: num_blocks x page_block_size x num_heads_k x head_size i + ''' + output[:num_prefill_tokens] = self.fa_prefix_attn_func( # noqa + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + seqused_k=prefill_meta.seq_lens_tensor, + max_seqlen_k=max_seq_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + block_table=prefill_meta.block_tables, + softcap=self.logits_soft_cap, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + + # Skip decode phase for encoder-only models + if (decode_meta := attn_metadata.decode_metadata) and ( + self.attn_type != AttentionType.ENCODER_ONLY): + # Decoding run. + # Whether to use rocm custom paged attention or not + num_seqs, num_heads, head_size = decode_query.shape + block_size = value_cache.shape[3] + gqa_ratio = num_heads // self.num_kv_heads + use_custom = use_rocm_custom_paged_attention( + decode_query.dtype, head_size, block_size, gqa_ratio, + decode_meta.max_decode_seq_len, self.sliding_window, + self.kv_cache_dtype, self.alibi_slopes) + + if use_custom: + max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type + != AttentionType.ENCODER_DECODER else + decode_meta.max_encoder_seq_len) + assert max_seq_len is not None + max_num_partitions = ( + (max_seq_len + _PARTITION_SIZE_ROCM - 1) // + _PARTITION_SIZE_ROCM) + assert _PARTITION_SIZE_ROCM % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=query.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + + query_start_loc = None + ops.paged_attention_rocm( + output[num_prefill_tokens:], + exp_sums, + max_logits, + tmp_output, + decode_query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + decode_meta.block_tables + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.cross_block_tables, + decode_meta.seq_lens_tensor + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.encoder_seq_lens_tensor, + query_start_loc, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + output_scale, + ) + else: + # PagedAttention does not support fused quant, manually quantize + if output_scale is None: + out_pa = output[num_prefill_tokens:] + else: + out_pa = torch.empty_like(output[num_prefill_tokens:], + dtype=query.dtype) + + tree_attention_masks_tensor = decode_meta.tree_attention_masks_tensor + if envs.VLLM_USE_FLASH_ATTN_PA: + from flash_attn import vllm_flash_attn_with_kvcache + if envs.VLLM_USE_PA_PRINT_PARAM: + print("PA SIZE:") + print(f"q.shape = {decode_query.unsqueeze(1).shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}, kv_cache_dtype = {self.kv_cache_dtype}") + print(f"block_size= {block_size}, cache_seqlens.shape = {decode_meta.seq_lens_tensor.shape}, block_tables.shape = {decode_meta.block_tables.shape}") + print(f"softmax_scale = {self.scale:.3f}, window_size = {self.sliding_window}, softcap = {self.logits_soft_cap}, alibi_slopes = {self.alibi_slopes}") + + # output[num_prefill_tokens:] = self.fa_decode_attn_func( + output[num_prefill_tokens:] = vllm_flash_attn_with_kvcache( + q=decode_query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, + cache_seqlens=decode_meta.seq_lens_tensor, + block_table=decode_meta.block_tables, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + softcap=self.logits_soft_cap, + alibi_slopes=self.alibi_slopes, + return_softmax_lse=False, + k_scale=layer._k_scale, + v_scale=layer._v_scale, + kv_cache_dtype=self.kv_cache_dtype, + ).squeeze(1) + else: + out_pa[:] = paged_attn.forward_decode( + decode_query, + key_cache, + value_cache, + decode_meta.block_tables + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.cross_block_tables, + decode_meta.seq_lens_tensor + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.encoder_seq_lens_tensor, + decode_meta.max_decode_seq_len + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.max_encoder_seq_len, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + layer._k_scale, + layer._v_scale, + attn_masks=tree_attention_masks_tensor, + attn_masks_stride=tree_attention_masks_tensor.stride(0) if tree_attention_masks_tensor is not None else 0 + ) + + # Manually perform quantization + if output_scale is not None: + out_uq = out_pa.view(-1, self.num_heads * self.head_size) + out_q = output.view(-1, self.num_heads * self.head_size) + ops.scaled_fp8_quant(out_uq, + output_scale, + output=out_q[num_prefill_tokens:]) + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + +def _sdpa_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + seq_lens: torch.Tensor, + num_tokens: int, + num_heads: int, + head_size: int, + scale: float, + attn_masks: Optional[List[torch.Tensor]] = None, +) -> torch.Tensor: + start = 0 + assert output.shape == (num_tokens, num_heads, head_size) + assert output.dtype == query.dtype + assert output.device == query.device + + for i, seq_len in enumerate(seq_lens): + end = start + seq_len + with torch.nn.attention.sdpa_kernel( + torch.nn.attention.SDPBackend.MATH): + sub_out = torch.nn.functional.scaled_dot_product_attention( + query[:, start:end, :], + key[:, start:end, :], + value[:, start:end, :], + dropout_p=0.0, + is_causal=attn_masks is None, + attn_mask=attn_masks[i] if attn_masks else None, + scale=scale).movedim(query.dim() - 2, 0) + output[start:end, :, :] = sub_out + start = end + + return output diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py new file mode 100644 index 0000000..af5fe81 --- /dev/null +++ b/vllm/attention/backends/torch_sdpa.py @@ -0,0 +1,707 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" Attention layer with torch scaled_dot_product_attention + and PagedAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +from torch.nn.functional import scaled_dot_product_attention + +# yapf conflicts with isort for this block +# yapf: disable +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType, + is_quantized_kv_cache) +# yapf: enable +from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.ops.ipex_attn import PagedAttention, _use_ipex +from vllm.attention.ops.paged_attn import PagedAttentionMetadata +from vllm.logger import init_logger +from vllm.utils import make_tensor_with_pad +from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder + +logger = init_logger(__name__) + + +class TorchSDPABackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "TORCH_SDPA" + + @staticmethod + def get_impl_cls() -> Type["TorchSDPABackendImpl"]: + return TorchSDPABackendImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return TorchSDPAMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]: + return TorchSDPAMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + raise NotImplementedError("Swap is not supported in TorchSDPABackend.") + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for TorchSDPABackend. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + chunked_prefill: bool + seq_lens: Optional[List[int]] = None # For non-chunked prefill + + # For chunked prefill only + max_query_len: Optional[int] = None + max_kv_len: Optional[int] = None + prefill_query_start_loc: Optional[torch.Tensor] = None + kv_start_loc: Optional[torch.Tensor] = None + prefill_block_tables: Optional[torch.Tensor] = None + + # For V1 logits index only + query_start_loc: Optional[torch.Tensor] = None + + # Begin encoder attn & enc/dec cross-attn fields... + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. + # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[torch.Tensor]] = None + self.encoder_attn_bias: Optional[List[torch.Tensor]] = None + self.cross_attn_bias: Optional[List[torch.Tensor]] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return ((self.encoder_seq_lens is not None) + and (self.encoder_seq_lens_tensor is not None) + and (self.max_encoder_seq_len is not None)) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return (self.is_all_encoder_attn_metadata_set + and (self.cross_slot_mapping is not None) + and (self.cross_block_tables is not None)) + + @property + def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: + if self.num_prefill_tokens == 0: + return None + return self + + @property + def decode_metadata(self) -> Optional["TorchSDPAMetadata"]: + if self.num_decode_tokens == 0: + return None + return self + + def get_seq_lens( + self, + attn_type: str, + ): + ''' + Extract appropriate sequence lengths from attention metadata + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate sequence lengths tensor for query + * Appropriate sequence lengths tensor for key & value + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + seq_lens_q = self.seq_lens + seq_lens_kv = self.seq_lens + elif attn_type == AttentionType.ENCODER: + seq_lens_q = self.encoder_seq_lens + seq_lens_kv = self.encoder_seq_lens + elif attn_type == AttentionType.ENCODER_DECODER: + seq_lens_q = self.seq_lens + seq_lens_kv = self.encoder_seq_lens + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + return seq_lens_q, seq_lens_kv + + def get_attn_bias( + self, + attn_type: str, + ) -> Optional[List[torch.Tensor]]: + ''' + Extract appropriate attention bias from attention metadata + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate attention bias value given the attention type + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + return self.attn_bias + elif attn_type == AttentionType.ENCODER: + return self.encoder_attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + return self.cross_attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + def set_attn_bias( + self, + attn_bias: List[torch.Tensor], + attn_type: str, + ) -> None: + ''' + Update appropriate attention bias field of attention metadata, + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_bias: The desired attention bias value + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + self.attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER: + self.encoder_attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + self.cross_attn_bias = attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + def get_seq_len_block_table_args( + self, + attn_type: str, + ) -> tuple: + ''' + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. + + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + return (self.seq_lens_tensor, self.max_decode_seq_len, + self.block_tables) + elif attn_type == AttentionType.ENCODER_DECODER: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, + self.cross_block_tables) + elif attn_type == AttentionType.ENCODER: + # No block tables associated with encoder attention + return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, + None) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]): + + def __init__(self, input_builder: ModelInputForCPUBuilder) -> None: + self.chunked_prefill = input_builder.chunked_prefill + self.input_builder = input_builder + + def prepare(self): + self.input_data = self.input_builder.input_data + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata: + input_data = self.input_data + prefill_seq_lens = seq_lens[0:input_data.num_prefills] + prefill_query_lens = query_lens[0:input_data.num_prefills] + slot_mapping = torch.tensor(input_data.slot_mapping, + dtype=torch.long, + device="cpu") + + # For chunked-prefill + if self.chunked_prefill and input_data.num_prefill_tokens != 0: + prefill_block_tables = make_tensor_with_pad( + self.input_data.prefill_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + query_lens_tensor = torch.tensor(prefill_query_lens, + dtype=torch.int32, + device="cpu") + kv_lens_tensor = torch.tensor(prefill_seq_lens, + dtype=torch.int32, + device="cpu") + query_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + kv_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + torch.cumsum(query_lens_tensor, + dim=0, + dtype=torch.int32, + out=query_start_loc[1:]) + torch.cumsum(kv_lens_tensor, + dim=0, + dtype=torch.int32, + out=kv_start_loc[1:]) + max_query_len = max(prefill_query_lens) + max_kv_len = max(prefill_seq_lens) + else: + prefill_block_tables = None + query_start_loc = None + kv_start_loc = None + max_query_len = None + max_kv_len = None + + # For paged attention + if input_data.num_decode_tokens != 0: + seq_lens_tensor = torch.tensor( + input_data.seq_lens[input_data.num_prefills:], + dtype=torch.int32, + device="cpu", + ) + block_tables = make_tensor_with_pad( + self.input_data.decode_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + else: + block_tables = torch.tensor([]) + seq_lens_tensor = torch.tensor( + input_data.seq_lens[:input_data.num_prefills], + dtype=torch.int32, + device="cpu", + ) + + # For multi-modal models + placeholder_index_maps = None + if len(input_data.multi_modal_inputs_list) != 0: + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + input_data.multi_modal_placeholder_maps.items() + } + + attn_metadata = TorchSDPAMetadata( + chunked_prefill=self.chunked_prefill, + seq_lens=prefill_seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_kv_len=max_kv_len, + prefill_query_start_loc=query_start_loc, + kv_start_loc=kv_start_loc, + max_decode_seq_len=input_data.max_decode_seq_len, + num_prefills=input_data.num_prefills, + num_prefill_tokens=input_data.num_prefill_tokens, + num_decode_tokens=input_data.num_decode_tokens, + block_tables=block_tables, + prefill_block_tables=prefill_block_tables, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, + ) + + return attn_metadata + + +class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if blocksparse_params is not None: + raise ValueError( + "Torch SPDA does not support block-sparse attention.") + if logits_soft_cap is not None: + logger.warning_once("Torch SPDA does not support logits soft cap. " + "Outputs may be slightly off.") + if use_irope: + logger.warning_once( + "Using irope in Torch SPDA is not supported yet, it will fall" + " back to global attention for long context.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + + if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex: + raise NotImplementedError( + "Torch SDPA backend FP8 KV cache requires " + "intel_extension_for_pytorch support.") + self.attn_type = attn_type + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: TorchSDPAMetadata, # type: ignore + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with torch SDPA and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for TorchSDPABackendImpl") + + # For warming-up + if attn_metadata is None: + return query + + attn_type = self.attn_type + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None + + if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): + # KV-cache during decoder-self- or + # encoder-decoder-cross-attention, but not + # during encoder attention. + # + # Even if there are no new key/value pairs to cache, + # we still need to break out key_cache and value_cache + # i.e. for later use by paged attention + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + if (key is not None) and (value is not None): + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + # During cross-attention decode, key & value will be None, + # preventing this IF-statement branch from running + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + PagedAttention.write_to_paged_cache( + key, value, key_cache, value_cache, updated_slot_mapping, + self.kv_cache_dtype, layer._k_scale, layer._v_scale) + + if attn_type != AttentionType.ENCODER: + # Decoder self-attention supports chunked prefill. + # Encoder/decoder cross-attention requires no chunked + # prefill (100% prefill or 100% decode tokens, no mix) + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + else: + # Encoder attention - chunked prefill is not applicable; + # derive token-count from query shape & and treat them + # as 100% prefill tokens + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens + num_decode_tokens = 0 + + if attn_type == AttentionType.DECODER: + # Only enforce this shape-constraint for decoder + # self-attention + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + if prefill_meta := attn_metadata.prefill_metadata: + if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore + assert attn_metadata.seq_lens is not None + self._run_sdpa_forward(output, + query, + key, + value, + prefill_meta, + attn_type=attn_type) + else: + # prefix-enabled attention + assert not self.need_mask + import intel_extension_for_pytorch.llm.modules as ipex_modules + output = torch.empty_like(query) + ipex_modules.PagedAttention.flash_attn_varlen_func( + output[:prefill_meta.num_prefill_tokens, :, :], + query[:prefill_meta.num_prefill_tokens, :, :], + key_cache, + value_cache, + prefill_meta.prefill_query_start_loc, + prefill_meta.kv_start_loc, + prefill_meta.max_query_len, + prefill_meta.max_kv_len, + self.scale, + True, + prefill_meta.prefill_block_tables, + self.alibi_slopes, + ) + + if decode_meta := attn_metadata.decode_metadata: + assert attn_type != AttentionType.ENCODER_ONLY, ( + "Encoder-only models should not have decode metadata.") + # Decoding run. + ( + seq_lens_arg, + max_seq_len_arg, + block_tables_arg, + ) = decode_meta.get_seq_len_block_table_args(attn_type) + + PagedAttention.forward_decode( + output[attn_metadata.num_prefill_tokens:, :, :], + query[attn_metadata.num_prefill_tokens:, :, :], + key_cache, + value_cache, + block_tables_arg, + seq_lens_arg, + max_seq_len_arg, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + layer._k_scale, + layer._v_scale, + ) + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + def _run_sdpa_forward( + self, + output: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: TorchSDPAMetadata, + attn_type: str = AttentionType.DECODER, + ) -> None: + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, dim=1) + + attn_masks = attn_metadata.get_attn_bias(attn_type) + if attn_masks is None: + if self.alibi_slopes is not None: + attn_masks = _make_alibi_bias( + self.alibi_slopes, query.dtype, + attn_metadata.seq_lens) # type: ignore + elif self.sliding_window is not None: + assert attn_metadata.seq_lens is not None + attn_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, self.sliding_window, + query.dtype) # type: ignore + else: + seq_lens, _ = attn_metadata.get_seq_lens(attn_type) + attn_masks = [None] * len(seq_lens) + attn_metadata.set_attn_bias(attn_masks, attn_type) + + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + causal_attn = (attn_type == AttentionType.DECODER) + + seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type) + start_q, start_kv = 0, 0 + for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, + attn_masks): + end_q = start_q + seq_len_q + end_kv = start_kv + seq_len_kv + sub_out = scaled_dot_product_attention( + query[None, :, start_q:end_q, :], + key[None, :, start_kv:end_kv, :], + value[None, :, start_kv:end_kv, :], + attn_mask=mask, + dropout_p=0.0, + is_causal=causal_attn and mask is None, + scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + output[start_q:end_q, :, :] = sub_out + start_q, start_kv = end_q, end_kv + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: List[int], +) -> List[torch.Tensor]: + attn_biases: List[torch.Tensor] = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat((num_heads, 1, 1)) + bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0) + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) + attn_biases.append((bias + inf_mask).to(dtype)) + + return attn_biases + + +def _make_sliding_window_bias( + seq_lens: List[int], + window_size: Optional[int], + dtype: torch.dtype, +) -> List[torch.Tensor]: + attn_biases: List[torch.Tensor] = [] + for seq_len in seq_lens: + tensor = torch.full( + (1, seq_len, seq_len), + dtype=dtype, + fill_value=1, + ) + shift = 0 + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.log(mask) + attn_biases.append(mask.to(dtype)) + + return attn_biases diff --git a/vllm/attention/backends/tree_decoding_utils.py b/vllm/attention/backends/tree_decoding_utils.py new file mode 100644 index 0000000..9938df6 --- /dev/null +++ b/vllm/attention/backends/tree_decoding_utils.py @@ -0,0 +1,55 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union, Optional +import torch + +from vllm.attention.backends.blocksparse_attn import BlocksparseFlashAttentionImpl +from vllm import _custom_ops as ops +from vllm.attention.ops.paged_attn import PagedAttention + +def move_cache( + backend, + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + kv_cache_dtype: str, + num_kv_heads: int, + head_size: int, + ) -> None: + if backend.get_name() == "rocm-flash-attn" or \ + backend.get_name() == "xformers": + + key_caches = [] + value_caches = [] + + num_layers = len(kv_caches) + token_num = src_to_dists.shape[0] + + tmp_store_kv = torch.empty( + (2, num_layers, token_num, num_kv_heads, head_size), + dtype=kv_caches[0].dtype, device=kv_caches[0].device) + keys = tmp_store_kv[0].contiguous() + values = tmp_store_kv[1].contiguous() + + for kv_cache in kv_caches: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, num_kv_heads, head_size) + key_caches.append(key_cache) + value_caches.append(value_cache) + + ops.read_cache( + keys, + values, + key_caches, + value_caches, + src_to_dists[:, 0].contiguous(), + kv_cache_dtype + ) + + ops.write_cache_multi_layers( + keys, + values, + key_caches, + value_caches, + src_to_dists[:, 1].contiguous(), + kv_cache_dtype + ) + else: + raise NotImplementedError("Only BlocksparseFlashAttention/ROCmFlash/XFormers backends support move cache for now!") \ No newline at end of file diff --git a/vllm/attention/backends/triton_config.py b/vllm/attention/backends/triton_config.py new file mode 100644 index 0000000..086c967 --- /dev/null +++ b/vllm/attention/backends/triton_config.py @@ -0,0 +1,184 @@ +import functools +import json +import torch +import os +from enum import Enum +from typing import Any, Dict, Optional, Tuple +import bisect +from vllm.logger import init_logger +logger = init_logger(__name__) + +class KERNLE_KINDS(Enum): + v1_2stages = 0 + v1_2stages_tc = 1 + v2 = 2 + v2_tc = 3 + TOTAL_KIND = 4 + +class BestConfig(): + def __init__(self): + self.batch_size = 0 + self.seq_len = 0 + self.kernel_kind = KERNLE_KINDS.TOTAL_KIND + self.BLOCK_N = 0 + self.BLOCK_DIM = 0 + # self.BLOCK_SEQ = 0 + # self.SPLIT_K = 0 + self.num_stages = 0 + self.num_warps = 0 + self.NUM_KV_SPLITS = 0 + self.BLOCK_N_2 = 0 + self.num_stages_2 = 0 + self.num_warps_2 = 0 + self.best_us = 0 + self.decode_fwd_stage1 = None + self.decode_fwd_stage2 = None + +def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> str: + if cache_dtype == "default": + return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_default.json" + + device_name = torch.cuda.get_device_name().replace(" ", "_") + if "K100_AI" in device_name: + return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_K100AI.json" + elif "BW" in device_name: + return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_BW.json" + else: + raise ValueError(f"Unsurpport device name: {device_name}") + + +def get_attention_mla_configs_json(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]: + + # First look up if an optimized configuration is available in the configs + # directory + json_file_name = get_mla_config_file_name(QH, KVH, QKD, VD, cache_dtype) + + config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name + ) + if os.path.exists(config_file_path): + with open(config_file_path) as f: + # logger.info("Using decode attention configuration from %s for attention layer.", config_file_path) + # If a configuration has been found, return it + return json.load(f) + else: + logger.warning("Can not find best decode attention configuration %s for attention layer, it may not have the best performance to use default json. Please tune one. ", config_file_path) + + json_file_name = get_mla_config_file_name(16, 1, 576, 512, "default") + config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name + ) + if os.path.exists(config_file_path): + with open(config_file_path) as f: + logger.warning("Using default decode attention configuration from %s for attention layer. It may not have the best performance to use default json. ", config_file_path) + # If a configuration has been found, return it + return json.load(f) + else: + raise ValueError("Please surpport default config can match 16 1 576 512") + + # If no optimized configuration is available, we will use the default + # configuration + return None + + +def get_config_map(attention_configs): + ret_map = {} + for bs in attention_configs.keys(): + int_bs = int(bs) + seq_map = {} + seq_configs = attention_configs[bs] + ret_map[int_bs] = seq_map + for seq_len in seq_configs.keys(): + int_seq_len = int(seq_len) + kind_config = seq_configs[seq_len] + configs = BestConfig() + # configs.batch_size = int_bs + # configs.seq_len = int_seq_len + configs.best_us = kind_config['best_us'] + seq_map[int_seq_len] = configs + if kind_config['kernel_kind'] == 'v1_2stages': + best_config = kind_config['best_config'] + stage1 = best_config['stage1'] + stage2 = best_config['stage2'] + configs.kernel_kind = KERNLE_KINDS.v1_2stages + # configs.SPLIT_K = stage1['SPLIT_K'] + configs.BLOCK_N = stage1['BLOCK_N'] + configs.num_stages = stage1['num_stages'] + configs.num_warps = stage1['num_warps'] + configs.BLOCK_N_2 = stage2['BLOCK_N'] + configs.num_stages_2 = stage2['num_stages'] + configs.num_warps_2 = stage2['num_warps'] + elif kind_config['kernel_kind'] == 'v1_2stages_tc': + best_config = kind_config['best_config'] + stage1 = best_config['stage1'] + stage2 = best_config['stage2'] + configs.kernel_kind = KERNLE_KINDS.v1_2stages_tc + # configs.SPLIT_K = stage1['SPLIT_K'] + configs.BLOCK_N = stage1['BLOCK_N'] + configs.num_stages = stage1['num_stages'] + configs.num_warps = stage1['num_warps'] + configs.BLOCK_N_2 = stage2['BLOCK_N'] + configs.num_stages_2 = stage2['num_stages'] + configs.num_warps_2 = stage2['num_warps'] + elif kind_config['kernel_kind'] == 'v2': + best_config = kind_config['best_config'] + stage1 = best_config['stage1'] + stage2 = best_config['stage2'] + configs.kernel_kind = KERNLE_KINDS.v2 + # if 'BLOCK_SEQ' in stage1: + # configs.BLOCK_SEQ = stage1['BLOCK_SEQ'] + # else: + # configs.NUM_KV_SPLITS = stage1['NUM_KV_SPLITS'] + configs.BLOCK_N = stage1['BLOCK_N'] + configs.num_stages = stage1['num_stages'] + configs.num_warps = stage1['num_warps'] + configs.num_stages_2 = stage2['num_stages'] + configs.num_warps_2 = stage2['num_warps'] + elif kind_config['kernel_kind'] == 'v2_tc': + best_config = kind_config['best_config'] + stage1 = best_config['stage1'] + stage2 = best_config['stage2'] + configs.kernel_kind = KERNLE_KINDS.v2_tc + # if 'BLOCK_SEQ' in stage1: + # configs.BLOCK_SEQ = stage1['BLOCK_SEQ'] + # else: + # configs.NUM_KV_SPLITS = stage1['NUM_KV_SPLITS'] + configs.BLOCK_N = stage1['BLOCK_N'] + configs.BLOCK_DIM = stage1['BLOCK_DIM'] + configs.num_stages = stage1['num_stages'] + configs.num_warps = stage1['num_warps'] + configs.num_stages_2 = stage2['num_stages'] + configs.num_warps_2 = stage2['num_warps'] + return ret_map + + +@functools.lru_cache +def get_attention_mla_configs(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]: + attention_configs = get_attention_mla_configs_json(QH, KVH, QKD, VD, cache_dtype) + return get_config_map(attention_configs) + + +def get_closest_key(dic_keys, target_key): + keys = list(dic_keys) + idx = bisect.bisect_left(keys, target_key) + if idx == 0: + return keys[0] + if idx == len(keys): + return keys[-1] + left_key = keys[idx - 1] + right_key = keys[idx] + if target_key - left_key <= right_key - target_key: + return left_key + else: + return right_key + +def get_nearest_config(bs_key, mean_kv_seqlen_key, config): + closest_bs_key = get_closest_key(config.keys(), bs_key) + closest_mean_kv_seqlen_key = get_closest_key(config[closest_bs_key].keys(), mean_kv_seqlen_key) + return config[closest_bs_key][closest_mean_kv_seqlen_key] + +def get_config(bs_key, mean_kv_seqlen_key, config): + if bs_key in config and mean_kv_seqlen_key in config[bs_key]: + return config[bs_key][mean_kv_seqlen_key] + else: + raise ValueError(f"No matching configuration found for bs key: {bs_key} and mean kv seq key: {mean_kv_seqlen_key} when init decode attention db") \ No newline at end of file diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py new file mode 100644 index 0000000..3cf2d21 --- /dev/null +++ b/vllm/attention/backends/triton_mla.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os + +from typing import Any, Dict, List, Optional, Type +from .triton_config import get_nearest_config, get_attention_mla_configs, get_config, get_attention_mla_configs_json + +import torch + +from vllm.attention.backends.abstract import (AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata) +from vllm.attention.ops.triton_decode_attention import decode_attention_fwd +import vllm.envs as envs + +from vllm.logger import init_logger +logger = init_logger(__name__) + + +class TritonMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "TRITON_MLA" + + @staticmethod + def get_impl_cls() -> Type["TritonMLAImpl"]: + return TritonMLAImpl + + +class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "TritonMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TritonMLAImpl") + + if envs.VLLM_USE_TRITON_OPT_MLA: + self.attn_configs = get_attention_mla_configs_json(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16") + + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "TritonMLA with FP8 KV cache not yet supported") + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + B = q_nope.shape[0] + + q = torch.cat([q_nope, q_pe], dim=-1) + o = torch.zeros(B, + self.num_heads, + self.kv_lora_rank, + dtype=q.dtype, + device=q.device) + + num_kv_splits = 4 # TODO: heuristic + + # TODO(lucas) Allocate ahead of time + attn_logits = torch.empty( + ( + B, + self.num_heads, + num_kv_splits, + # NOTE(lucas) idk why the +1 is here but sglang has it so we + # just mirror that + self.kv_lora_rank + 1, + ), + dtype=torch.float32, + device=q.device, + ) + + # Add a head dim of 1 + kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) + kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] + PAGE_SIZE = kv_c_and_k_pe_cache.size(1) + + # TODO + max_seq_len = torch.max(decode_meta.seq_lens_tensor).item() + if os.environ.get('PA_MATCH_USE_MEAN_SEQ') == '1': + match_seq_len = int((decode_meta.seq_lens_tensor.sum()/ max(1, B)).item()) + else: + match_seq_len = max_seq_len + + if envs.VLLM_USE_TRITON_OPT_MLA: + best_config = self.attn_configs[min(self.attn_configs.keys(), key=lambda x: abs(int(x) - match_seq_len))] + + # Run MQA + decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, attn_logits, + num_kv_splits, self.scale, best_config, PAGE_SIZE) + + return self._v_up_proj(o) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py new file mode 100644 index 0000000..2f5414c --- /dev/null +++ b/vllm/attention/backends/utils.py @@ -0,0 +1,635 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention backend utils""" +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass +from itertools import accumulate +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, + TypeVar, Union) + +import numpy as np +import torch + +from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, + AttentionState) +from vllm.attention.backends.abstract import AttentionType +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.multimodal import MultiModalPlaceholderMap +from vllm.utils import async_tensor_h2d, make_tensor_with_pad + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from vllm.worker.model_runner_base import ModelRunnerBase + +# Error string(s) for encoder/decoder +# unsupported attention scenarios +STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported " + "with encoder/decoder models.") + +PAD_SLOT_ID = -1 + +# Switch to numpy implementation of compute_slot_mapping +# if we have at least this many elements. Could be tuned further. +_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256 + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + + +def is_block_tables_empty(block_tables: Union[None, Dict]): + """ + Check if block_tables is None or a dictionary with all None values. + """ + if block_tables is None: + return True + return (isinstance(block_tables, dict) + and all(value is None for value in block_tables.values())) + + +def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int, + context_len: int, sliding_window: int): + """ + Compute the start index of slot mapping. + """ + start_idx = 0 + if is_prompt and sliding_window is not None: + start_idx = max(0, query_len - sliding_window) + return start_idx + + +def _compute_slot_mapping_python(slot_mapping: List[int], + block_table: List[int], range_start: int, + range_end: int, block_size: int): + for i in range(range_start, range_end): + block_number = block_table[i // block_size] + block_offset = i % block_size + slot = block_number * block_size + block_offset + slot_mapping.append(slot) + + +def _compute_slot_mapping_numpy(slot_mapping: List[int], + block_table: List[int], range_start: int, + range_end: int, block_size: int): + block_table_array = np.array(block_table) + idx = np.arange(range_start, range_end) + block_offset = idx % block_size + idx //= block_size + seq_slot_mapping_array = block_table_array[idx] + seq_slot_mapping_array *= block_size + seq_slot_mapping_array += block_offset + slot_mapping.extend(seq_slot_mapping_array) + + +def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int], + seq_id: int, seq_len: int, context_len: int, + start_idx: int, block_size: int, + block_tables: Dict[int, List[int]]): + """ + Compute slot mapping. + """ + if is_profile_run: + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + slot_mapping.extend([PAD_SLOT_ID] * seq_len) + return + + # Mask the [0, start_idx) tokens of the prompt with + # PAD_SLOT_ID, where start_idx is max(0, seq_len - + # sliding_window). For example, if the prompt len is 10, + # sliding window is 8, and block size is 4, the first two + # tokens are masked and the slot mapping will be + # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + padding_mask_len = max(0, start_idx - context_len) + slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len) + + range_start = max(start_idx, context_len) + range_end = seq_len + numel = range_end - range_start + block_table = block_tables[seq_id] + + # numpy implementation will be faster than python if we have + # many elements, otherwise it will be slower. + if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL: + _compute_slot_mapping_python(slot_mapping, block_table, range_start, + range_end, block_size) + else: + _compute_slot_mapping_numpy(slot_mapping, block_table, range_start, + range_end, block_size) + + +TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata') + + +class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): + + _metadata_cls: Type[TAttentionMetadata] + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if inter_data.prefix_cache_hit: + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = self.runner.graph_block_tables[:batch_size] + for i, block_table in enumerate(self.block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + # block_tables = torch.from_numpy(input_block_tables).to( + # device, non_blocking=True) + block_tables = torch.from_numpy(input_block_tables).pin_memory().to( + device, non_blocking=True) + + else: + has_empty: bool = any(len(bt) == 0 for bt in self.block_tables) + has_non_empty = any(len(bt) > 0 for bt in self.block_tables) + max_block_length = 0 + if has_empty and has_non_empty: + for inter_data in self.input_builder.inter_data_list: + block_tables = inter_data.block_tables + if block_tables: + for seq_id in inter_data.seq_ids: + if seq_id in block_tables: + block_table = block_tables[seq_id] + max_block_length = max(max_block_length, len(block_table)) + if max_block_length >0: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + max_len=max_block_length, + ) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + + assert max_query_len > 0, "query_lens: {}".format(query_lens) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + + return self._metadata_cls( # type: ignore + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + block_tables_list=self.block_tables + ) + + +class CommonAttentionState(AttentionState): + + def __init__(self, runner: "ModelRunnerBase"): + self.runner = runner + self._is_graph_capturing = False + + @contextmanager + def graph_capture(self, max_batch_size: int): + + self._is_graph_capturing = True + + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + + yield + + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + + def graph_clone(self, batch_size: int) -> "CommonAttentionState": + assert self._is_graph_capturing + return self.__class__(self.runner) + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + assert self._is_graph_capturing + attn_metadata = self.runner.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=self._graph_slot_mapping[:batch_size], + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=self._graph_seq_lens[:batch_size], + max_query_len=1, + max_decode_query_len=1, + max_prefill_seq_len=0, + max_decode_seq_len=self.runner.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self._graph_block_tables[:batch_size], + use_cuda_graph=True, + ) + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. + assert self.runner.attn_backend.get_name() in \ + ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \ + f"Expected attn_backend name to be either 'XFORMERS'," \ + f"'ROCM_FLASH', or 'FLASH_ATTN', but " \ + f"got '{self.runner.attn_backend.get_name()}'" + self._update_captured_metadata_for_enc_dec_model( + batch_size=batch_size, attn_metadata=attn_metadata) + + return attn_metadata + + def get_graph_input_buffers( + self, + attn_metadata, + is_encoder_decoder_model: bool = False) -> Dict[str, Any]: + input_buffers = { + "slot_mapping": attn_metadata.slot_mapping, + "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": attn_metadata.decode_metadata.block_tables, + } + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. + assert self.runner.attn_backend.get_name() in \ + ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \ + f"Expected attn_backend name to be either 'XFORMERS'," \ + f"'ROCM_FLASH', or 'FLASH_ATTN', but " \ + f"got '{self.runner.attn_backend.get_name()}'" + self._add_additional_input_buffers_for_enc_dec_model( + attn_metadata=attn_metadata, input_buffers=input_buffers) + return input_buffers + + def prepare_graph_input_buffers( + self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False) -> None: + input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) + input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. + assert self.runner.attn_backend.get_name() in\ + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or "\ + f"'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" + self._prepare_input_buffers_for_enc_dec_model( + attn_metadata, input_buffers) + + def begin_forward(self, model_input) -> None: + return + + def _update_captured_metadata_for_enc_dec_model(self, batch_size: int, + attn_metadata): + """ + Updates the attention metadata parameters for CUDA graph capture in an + encoder-decoder model. + + This method modifies attention-related tensors and metadata required + for CUDA graph capture in encoder-decoder models. Specifically, it + updates the cross-attention and encoder sequence tensors in the + AttentionMetadata object. + """ + # During decode phase the cross_slot_mapping will be empty. Hence set + # an empty tensor for CUDA Graph capture. + attn_metadata.cross_slot_mapping = torch.tensor( + [], dtype=torch.int).cuda() + attn_metadata.cross_block_tables = torch.full( + (batch_size, self.runner.get_max_block_per_batch()), + 1, + dtype=torch.int).cuda() + attn_metadata.encoder_seq_lens = torch.full((batch_size, ), + 1, + dtype=torch.int).cuda() + attn_metadata.encoder_seq_lens_tensor = torch.full( + (batch_size, ), 1, dtype=torch.int).cuda() + attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture + attn_metadata.num_encoder_tokens = 0 + + def _add_additional_input_buffers_for_enc_dec_model( + self, attn_metadata, input_buffers: Dict[str, Any]): + """ + Saves additional input buffers specific to the encoder-decoder model + from the attention metadata. + + This method extracts and stores encoder-decoder related input buffers + from the `attn_metadata` into the `input_buffers` dictionary. The + buffers include encoder sequence lengths, cross-slot mappings, and + cross-block tables, which are essential for the encoder-decoder model + during CUDA graph replay. + """ + input_buffers["encoder_seq_lens_tensor"] = ( + attn_metadata.decode_metadata.encoder_seq_lens_tensor) + input_buffers["cross_slot_mapping"] = ( + attn_metadata.decode_metadata.cross_slot_mapping) + input_buffers["cross_block_tables"] = ( + attn_metadata.decode_metadata.cross_block_tables) + + def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata, + input_buffers: Dict[str, + Any]): + """ + Populates input buffers with data from the encoder-decoder model's + attention metadata. + + This method fills the input buffers with encoder-decoder specific + tensors. It copies data from the `attn_metadata` and keyword arguments + (`kwargs`) into corresponding buffers in the `input_buffers` dictionary. + The copied data includes attention-related metadata as well as input + IDs and positional information for the encoder. + """ + input_buffers["encoder_seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.encoder_seq_lens_tensor, + non_blocking=True) + input_buffers["cross_slot_mapping"].copy_( + attn_metadata.decode_metadata.cross_slot_mapping, + non_blocking=True) + input_buffers["cross_block_tables"].copy_( + attn_metadata.decode_metadata.cross_block_tables, + non_blocking=True) + + +def is_all_encoder_attn_metadata_set(attn_metadata): + ''' + All attention metadata required for encoder attention is set. + ''' + return ((attn_metadata.encoder_seq_lens is not None) + and (attn_metadata.encoder_seq_lens_tensor is not None) + and (attn_metadata.max_encoder_seq_len is not None)) + + +def is_all_cross_attn_metadata_set(attn_metadata): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return (attn_metadata.is_all_encoder_attn_metadata_set + and (attn_metadata.cross_slot_mapping is not None) + and (attn_metadata.cross_block_tables is not None)) + + +def get_seq_len_block_table_args( + attn_metadata, + is_prompt: bool, + attn_type: str, +) -> tuple: + ''' + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. + + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention op + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) + ''' + + if attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_lens_tensor, max_seq_len, + attn_metadata.block_tables) + elif attn_type == AttentionType.ENCODER_DECODER: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, + attn_metadata.cross_block_tables) + elif attn_type == AttentionType.ENCODER: + # No block tables associated with encoder attention + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, None) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def get_num_prefill_decode_query_kv_tokens( + attn_metadata, + attn_type: str, +) -> Tuple[int, int, int]: + """ + Calculate the number of prefill and decode tokens for query, key/value + based on the attention metadata and the specified attention type. + + Args: + attn_metadata (AttentionMetadata): Attention Metadata object. + attn_type (AttentionType): The type of attention being used. + Returns: + Tuple[int, int, int]: A tuple containing three integers: + - The number of prefill query tokens. + - The number of prefill key/value tokens. + - The number of decode query tokens. + + Raises: + AssertionError: If the number of encoder tokens in `attn_metadata` + is `None` when required for the calculations. + """ + num_prefill_query_tokens = 0 + num_decode_query_tokens = 0 + num_prefill_kv_tokens = 0 + if attn_type == AttentionType.ENCODER: + # Encoder attention is only invoked during prefill phase. + # The same input servers a both query and key. + assert attn_metadata.num_encoder_tokens is not None + num_prefill_query_tokens = attn_metadata.num_encoder_tokens + num_prefill_kv_tokens = attn_metadata.num_encoder_tokens + num_decode_query_tokens = 0 + elif attn_type == AttentionType.ENCODER_DECODER: + assert attn_metadata.num_encoder_tokens is not None + num_prefill_query_tokens = attn_metadata.num_prefill_tokens + # The key is the encoder/cross-attention. + num_prefill_kv_tokens = attn_metadata.num_encoder_tokens + num_decode_query_tokens = attn_metadata.num_decode_tokens + else: # attn_type == AttentionType.DECODER or + # attn_type == AttentionType.ENCODER_ONLY + num_prefill_query_tokens = attn_metadata.num_prefill_tokens + num_prefill_kv_tokens = attn_metadata.num_prefill_tokens + num_decode_query_tokens = attn_metadata.num_decode_tokens + + return (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) + + +@dataclass +class MLADims: + q_lora_rank: Optional[int] + kv_lora_rank: int + qk_nope_head_dim: int + qk_rope_head_dim: int + v_head_dim: int + + +def get_mla_dims(model_config: ModelConfig) -> MLADims: + hf_text_config = model_config.hf_text_config + + return MLADims( + q_lora_rank=getattr(hf_text_config, "q_lora_rank", None), + kv_lora_rank=hf_text_config.kv_lora_rank, + qk_nope_head_dim=hf_text_config.qk_nope_head_dim, + qk_rope_head_dim=hf_text_config.qk_rope_head_dim, + v_head_dim=hf_text_config.v_head_dim, + ) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py new file mode 100644 index 0000000..84068f5 --- /dev/null +++ b/vllm/attention/backends/xformers.py @@ -0,0 +1,818 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer with xFormers and PagedAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import (AttentionBias, + BlockDiagonalCausalMask, + BlockDiagonalMask, + LowerTriangularMaskWithTensorBias) + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import ( + CommonAttentionState, CommonMetadataBuilder, + get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, + is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set) +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class XFormersBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "XFORMERS" + + @staticmethod + def get_impl_cls() -> Type["XFormersImpl"]: + return XFormersImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return XFormersMetadata + + @staticmethod + def get_builder_cls() -> Type["XFormersMetadataBuilder"]: + return XFormersMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for XFormersbackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # FIXME: It is for flash attn. + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] = None + + # FIXME: It is for flash attn. + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] = None + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + + # Self-attention prefill/decode metadata cache + _cached_prefill_metadata: Optional["XFormersMetadata"] = None + _cached_decode_metadata: Optional["XFormersMetadata"] = None + + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + # FIXME: It is for flash attn. + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + encoder_seq_start_loc: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + tree_attention_masks_tensor: Optional[torch.Tensor] = None + block_tables_list: Optional[List[int]] = None + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. + # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[AttentionBias]] = None + self.encoder_attn_bias: Optional[List[AttentionBias]] = None + self.cross_attn_bias: Optional[List[AttentionBias]] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return is_all_encoder_attn_metadata_set(self) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return is_all_cross_attn_metadata_set(self) + + @property + def prefill_metadata(self) -> Optional["XFormersMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + # Recover cached prefill-phase attention + # metadata structure + return self._cached_prefill_metadata + + assert ((self.seq_lens is not None) + or (self.encoder_seq_lens is not None)) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) + + # Construct & cache prefill-phase attention metadata structure + self._cached_prefill_metadata = XFormersMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables, + tree_attention_masks_tensor=self.tree_attention_masks_tensor, + block_tables_list=self.block_tables_list) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["XFormersMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + # Recover cached decode-phase attention + # metadata structure + return self._cached_decode_metadata + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + + # Construct & cache decode-phase attention metadata structure + self._cached_decode_metadata = XFormersMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens_tensor=seq_lens_tensor, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + block_tables=block_tables, + use_cuda_graph=self.use_cuda_graph, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables, + tree_attention_masks_tensor=self.tree_attention_masks_tensor, + block_tables_list=self.block_tables_list) + + # Batch may be composed of prefill|decodes, adjust query start indices + # to refer to the start of decodes when the two are split apart. + # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + if self._cached_decode_metadata.query_start_loc is not None: + qs = self._cached_decode_metadata.query_start_loc + self._cached_decode_metadata.query_start_loc = qs - qs[0] + return self._cached_decode_metadata + + +def _get_attn_bias( + attn_metadata: XFormersMetadata, + attn_type: str, +) -> Optional[AttentionBias]: + ''' + Extract appropriate attention bias from attention metadata + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate attention bias value given the attention type + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + return attn_metadata.attn_bias + elif attn_type == AttentionType.ENCODER: + return attn_metadata.encoder_attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + return attn_metadata.cross_attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def _set_attn_bias( + attn_metadata: XFormersMetadata, + attn_bias: List[Optional[AttentionBias]], + attn_type: str, +) -> None: + ''' + Update appropriate attention bias field of attention metadata, + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_bias: The desired attention bias value + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + attn_metadata.attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER: + attn_metadata.encoder_attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + attn_metadata.cross_attn_bias = attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]): + + _metadata_cls = XFormersMetadata + + +class XFormersImpl(AttentionImpl[XFormersMetadata]): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if blocksparse_params is not None: + raise ValueError( + "XFormers does not support block-sparse attention.") + if logits_soft_cap is not None: + logger.warning_once("XFormers does not support logits soft cap. " + "Outputs may be slightly off.") + if use_irope: + logger.warning_once( + "Using irope in XFormers is not supported yet, it will fall" + " back to global attention for long context.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + + self.attn_type = attn_type + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], + kv_cache: torch.Tensor, + attn_metadata: "XFormersMetadata", + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with xFormers and PagedAttention. + + For decoder-only models: query, key and value must be non-None. + + For encoder/decoder models: + * XFormersImpl.forward() may be invoked for both self- and cross- + attention layers. + * For self-attention: query, key and value must be non-None. + * For cross-attention: + * Query must be non-None + * During prefill, key and value must be non-None; key and value + get cached for use during decode. + * During decode, key and value may be None, since: + (1) key and value tensors were cached during prefill, and + (2) cross-attention key and value tensors do not grow during + decode + + A note on how the attn_type (attention type enum) argument impacts + attention forward() behavior: + + * DECODER: normal decoder-only behavior; + use decoder self-attention block table + * ENCODER: no KV caching; pass encoder sequence + attributes (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) to kernel, in lieu of decoder + sequence attributes (seq_lens/seq_lens_tensor/max_seq_len). + Used for encoder branch of encoder-decoder models. + * ENCODER_ONLY: no kv_caching, uses the normal attention + attributes (seq_lens/seq_lens_tensor/max_seq_len). + * ENCODER_DECODER: cross-attention behavior; + use cross-attention block table for caching KVs derived + from encoder hidden states; since KV sequence lengths + will match encoder sequence lengths, pass encoder sequence + attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally + Returns: + shape = [num_tokens, num_heads * head_size] + """ + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for XFormersImpl") + + attn_type = self.attn_type + # Check that appropriate attention metadata attributes are + # selected for the desired attention type + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None + + # Self-attention vs. cross-attention will impact + # which KV cache memory-mapping & which + # seqlen datastructures we utilize + + if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): + # KV-cache during decoder-self- or + # encoder-decoder-cross-attention, but not + # during encoder attention. + # + # Even if there are no new key/value pairs to cache, + # we still need to break out key_cache and value_cache + # i.e. for later use by paged attention + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + if (key is not None) and (value is not None): + + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + # During cross-attention decode, key & value will be None, + # preventing this IF-statement branch from running + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory + # profiling run. + PagedAttention.write_to_paged_cache( + key, value, key_cache, value_cache, updated_slot_mapping, + self.kv_cache_dtype, layer._k_scale, layer._v_scale) + (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) = \ + get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_query_tokens:] + # QKV for prefill. + query = query[:num_prefill_query_tokens] + if key is not None and value is not None: + key = key[:num_prefill_kv_tokens] + value = value[:num_prefill_kv_tokens] + + assert query.shape[0] == num_prefill_query_tokens + assert decode_query.shape[0] == num_decode_query_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: + # normal attention. + # block tables are empty if the prompt does not have a cached + # prefix. + out = self._run_memory_efficient_xformers_forward( + query, key, value, prefill_meta, attn_type=attn_type) + assert out.shape == output[:num_prefill_query_tokens].shape + output[:num_prefill_query_tokens] = out + else: + assert attn_type != AttentionType.ENCODER_ONLY, ( + "Encoder-only models should not have prefix attention.") + + assert prefill_meta.query_start_loc is not None + assert prefill_meta.max_query_len is not None + + # prefix-enabled attention + # TODO(Hai) this triton kernel has regression issue (broke) to + # deal with different data types between KV and FP8 KV cache, + # to be addressed separately. + out = PagedAttention.forward_prefix( + query, + key, + value, + self.kv_cache_dtype, + key_cache, + value_cache, + prefill_meta.block_tables, + prefill_meta.query_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.max_query_len, + self.alibi_slopes, + self.sliding_window, + layer._k_scale, + layer._v_scale, + ) + assert output[:num_prefill_query_tokens].shape == out.shape + output[:num_prefill_query_tokens] = out + + if decode_meta := attn_metadata.decode_metadata: + assert attn_type != AttentionType.ENCODER_ONLY, ( + "Encoder-only models should not have decode metadata.") + + ( + seq_lens_arg, + max_seq_len_arg, + block_tables_arg, + ) = get_seq_len_block_table_args(decode_meta, False, attn_type) + + tree_attention_masks_tensor = decode_meta.tree_attention_masks_tensor + + output[num_prefill_query_tokens:] = PagedAttention.forward_decode( + decode_query, + key_cache, + value_cache, + block_tables_arg, + seq_lens_arg, + max_seq_len_arg, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + layer._k_scale, + layer._v_scale, + attn_masks=tree_attention_masks_tensor, + attn_masks_stride=tree_attention_masks_tensor.stride(0) if tree_attention_masks_tensor is not None else 0 + ) + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + def _run_memory_efficient_xformers_forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: XFormersMetadata, + attn_type: str = AttentionType.DECODER, + ) -> torch.Tensor: + """Attention for 1D query of multiple prompts. Multiple prompt + tokens are flattened in to `query` input. + + See https://facebookresearch.github.io/xformers/components/ops.html + for API spec. + + Args: + output: shape = [num_prefill_tokens, num_heads, head_size] + query: shape = [num_prefill_tokens, num_heads, head_size] + key: shape = [num_prefill_tokens, num_kv_heads, head_size] + value: shape = [num_prefill_tokens, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally + """ + + original_query = query + if self.num_kv_heads != self.num_heads: + # GQA/MQA requires the shape [B, M, G, H, K]. + # Note that the output also has the same shape (which is different + # from a spec from the doc). + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) + + # Set attention bias if not provided. This typically happens at + # the very attention layer of every iteration. + # FIXME(woosuk): This is a hack. + attn_bias = _get_attn_bias(attn_metadata, attn_type) + if attn_bias is None: + if self.alibi_slopes is None: + + # Cross attention block of decoder branch of encoder-decoder + # model uses seq_lens for dec / encoder_seq_lens for enc + if (attn_type == AttentionType.ENCODER_DECODER): + assert attn_metadata.seq_lens is not None + assert attn_metadata.encoder_seq_lens is not None + + # Cross-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.seq_lens, + attn_metadata.encoder_seq_lens, + device=query.device) + + # Encoder branch of encoder-decoder model uses + # attn_metadata.encoder_seq_lens + elif attn_type == AttentionType.ENCODER: + + assert attn_metadata.encoder_seq_lens is not None + + # Encoder self-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.encoder_seq_lens, device=query.device) + + # Self-attention block of encoder-only model just + # uses the seq_lens directly. + elif attn_type == AttentionType.ENCODER_ONLY: + assert attn_metadata.seq_lens is not None + + # Encoder self-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.seq_lens, device=query.device) + + # Self-attention block of decoder branch just + # uses the seq_lens directly + elif attn_type == AttentionType.DECODER: + assert attn_metadata.seq_lens is not None + + # Decoder self-attention mask is causal + attn_bias = BlockDiagonalCausalMask.from_seqlens( + attn_metadata.seq_lens, device=query.device) + else: + raise ValueError("Unknown AttentionType: %s", attn_type) + + if self.sliding_window is not None: + attn_bias = attn_bias.make_local_attention( + self.sliding_window) + attn_bias = [attn_bias] + else: + assert attn_type == AttentionType.DECODER + assert attn_metadata.seq_lens is not None + attn_bias = _make_alibi_bias(self.alibi_slopes, + self.num_kv_heads, query.dtype, + attn_metadata.seq_lens) + + _set_attn_bias(attn_metadata, attn_bias, attn_type) + + # No alibi slopes. + # TODO(woosuk): Too many view operations. Let's try to reduce + # them in the future for code readability. + if self.alibi_slopes is None: + # Add the batch dimension. + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=attn_bias[0], + p=0.0, + scale=self.scale) + return out.view_as(original_query) + + # Attention with alibi slopes. + # FIXME(woosuk): Because xformers does not support dynamic sequence + # lengths with custom attention bias, we process each prompt one by + # one. This is inefficient, especially when we have many short prompts. + assert attn_metadata.seq_lens is not None + output = torch.empty_like(original_query) + start = 0 + for i, seq_len in enumerate(attn_metadata.seq_lens): + end = start + seq_len + out = xops.memory_efficient_attention_forward( + query[None, start:end], + key[None, start:end], + value[None, start:end], + attn_bias=attn_bias[i], + p=0.0, + scale=self.scale) + # TODO(woosuk): Unnecessary copy. Optimize. + output[start:end].copy_(out.view_as(original_query[start:end])) + start += seq_len + return output + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + num_kv_heads: int, + dtype: torch.dtype, + seq_lens: List[int], +) -> List[AttentionBias]: + attn_biases: List[AttentionBias] = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + # Calculate a matrix where each element represents ith element- jth + # element. + bias = bias[None, :] - bias[:, None] + + padded_len = (seq_len + 7) // 8 * 8 + num_heads = alibi_slopes.shape[0] + bias = torch.empty( + 1, # batch size + num_heads, + seq_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :seq_len].copy_(bias) + bias.mul_(alibi_slopes[:, None, None]) + attn_biases.append(LowerTriangularMaskWithTensorBias(bias)) + + return attn_biases diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py new file mode 100644 index 0000000..920615e --- /dev/null +++ b/vllm/attention/layer.py @@ -0,0 +1,481 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer.""" +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import vllm.envs as envs +from vllm.attention import AttentionType +from vllm.attention.selector import backend_name_to_enum, get_attn_backend +from vllm.config import CacheConfig, get_current_vllm_config +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group, + is_v1_kv_transfer_group) +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.layers.linear import UnquantizedLinearMethod +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.platforms import _Backend, current_platform +from vllm.utils import direct_register_custom_op +from vllm.v1.attention.backends.utils import validate_kv_sharing_target + + +class Attention(nn.Module): + """Attention layer. + + This class takes query, key, and value tensors as input. The input tensors + can either contain prompt tokens or generation tokens. + The class does the following: + + 1. Store the input key and value tensors in the KV cache. + 2. Perform (multi-head/multi-query/grouped-query) attention. + 3. Return the output tensor. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + per_layer_sliding_window: Optional[int] = None, + use_mla: bool = False, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + **extra_impl_args, + ) -> None: + """ + The KV cache is stored inside this class and is accessed via + `self.kv_cache`. + """ + super().__init__() + if per_layer_sliding_window is not None: + # per-layer sliding window + sliding_window = per_layer_sliding_window + elif cache_config is not None: + # model-level sliding window + sliding_window = cache_config.sliding_window + else: + sliding_window = None + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + is_attention_free = cache_config.is_attention_free + calculate_kv_scales = cache_config.calculate_kv_scales + else: + kv_cache_dtype = "auto" + block_size = 64 if envs.VLLM_USE_FLASH_ATTN_PA or envs.VLLM_USE_FLASH_MLA else 16 + is_attention_free = False + calculate_kv_scales = False + if num_kv_heads is None: + num_kv_heads = num_heads + assert num_heads % num_kv_heads == 0, \ + f"num_heads ({num_heads}) is not " \ + f"divisible by num_kv_heads ({num_kv_heads})" + + # The default k/v_scale is set to 1.0. This is ignored + # when kv-cache is not fp8, and should be used with + # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we + # expect the pre-quantized k/v_scale to be loaded along + # with the model weights. + self.kv_cache_dtype = kv_cache_dtype + self.calculate_kv_scales = calculate_kv_scales + self._k_scale = torch.tensor(1.0, dtype=torch.float32) + self._v_scale = torch.tensor(1.0, dtype=torch.float32) + # FlashAttn doesn't support quantizing the kv-cache only + # but requires q to be quantized as well. + self._q_scale = torch.tensor(1.0, dtype=torch.float32) + self._prob_scale = torch.tensor(1.0, dtype=torch.float32) + + # We also keep the float32 versions of k/v_scale for attention + # backends that don't support tensors (Flashinfer) + self._k_scale_float = 1.0 + self._v_scale_float = 1.0 + + self.use_mla = use_mla + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_kv_heads + self.sliding_window = sliding_window + + quant_method = quant_config.get_quant_method( + self, prefix=prefix) if quant_config else None + if quant_method is not None and not isinstance( + quant_method, UnquantizedLinearMethod): + assert isinstance(quant_method, BaseKVCacheMethod) + # TODO (mgoin): kv cache dtype should be specified in the FP8 + # checkpoint config and become the "auto" behavior + if self.kv_cache_dtype == "fp8_e5m2": + raise ValueError("fp8_e5m2 kv-cache is not supported with " + "fp8 checkpoints.") + # If quantization is enabled, we make "k_scale" and "v_scale" + # parameters so that it can be loaded from the model checkpoint. + # The k/v_scale will then be converted back to native float32 + # values after weight loading. + self.quant_method = quant_method + self.quant_method.create_weights(self) + + # During model initialization, the default dtype is set as the model + # weight and activation dtype. + dtype = torch.get_default_dtype() + attn_backend = get_attn_backend(head_size, + dtype, + kv_cache_dtype, + block_size, + is_attention_free, + blocksparse_params is not None, + use_mla=use_mla) + impl_cls = attn_backend.get_impl_cls() + self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **extra_impl_args) + self.backend = backend_name_to_enum(attn_backend.get_name()) + self.dtype = dtype + + # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how + # torch.compile works by registering the attention as one giant + # opaque custom op. For other platforms, we directly call them + # and let torch.compile handle them. + self.use_direct_call = not current_platform.is_cuda_alike( + ) and not current_platform.is_cpu() + + self.use_output = attn_backend.accept_output_buffer + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + self.layer_name = prefix + self.attn_type = attn_type + + if kv_sharing_target_layer_name is not None: + if not envs.VLLM_USE_V1: + raise NotImplementedError( + "Cross-layer KV sharing is not supported in V0.") + + validate_kv_sharing_target( + prefix, + kv_sharing_target_layer_name, + compilation_config.static_forward_context, + ) + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + + # use a placeholder kv cache tensor during init, which will be replaced + # by bind_kv_cache + # this variable will not be accessed if use_direct_call is True + self.kv_cache = [ + torch.tensor([]) for _ in range(get_current_vllm_config( + ).parallel_config.pipeline_parallel_size) + ] + + self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) + self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) + self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + # For some alternate attention backends like MLA the attention output + # shape does not match the query shape, so we optionally let the model + # definition specify the output tensor shape. + output_shape: Optional[torch.Size] = None, + ) -> torch.Tensor: + """ + The KV cache is stored inside this class and is accessed via + `self.kv_cache`. + + Attention metadata (`attn_metadata`) is set using a context manager in + the model runner's `execute_model` method. It is accessed via forward + context using + `vllm.forward_context.get_forward_context().attn_metadata`. + """ + if self.calculate_kv_scales: + attn_metadata = get_forward_context().attn_metadata + if attn_metadata.enable_kv_scales_calculation: + self.calc_kv_scales(query, key, value) + if self.use_output: + output_shape = (output_shape + if output_shape is not None else query.shape) + output = torch.zeros(output_shape, + dtype=query.dtype, + device=query.device) + hidden_size = output_shape[-1] + # We skip reshaping query, key and value tensors for the MLA + # backend since these tensors have different semantics and are + # processed differently. + if not self.use_mla: + # Reshape the query, key, and value tensors. + # NOTE(woosuk): We do this outside the custom op to minimize the + # CPU overheads from the non-CUDA-graph regions. + query = query.view(-1, self.num_heads, self.head_size) + output = output.view(-1, self.num_heads, self.head_size) + if key is not None: + key = key.view(-1, self.num_kv_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_kv_heads, self.head_size) + if self.use_direct_call: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + self.impl.forward(self, + query, + key, + value, + self_kv_cache, + attn_metadata, + output=output) + else: + torch.ops.vllm.unified_attention_with_output( + query, key, value, output, self.layer_name) + return output.view(-1, hidden_size) + else: + if self.use_direct_call: + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + return self.impl.forward(self, query, key, value, + self_kv_cache, attn_metadata) + else: + return torch.ops.vllm.unified_attention( + query, key, value, self.layer_name) + + def calc_kv_scales(self, query, key, value): + self._q_scale.copy_(torch.abs(query).max() / self.q_range) + self._k_scale.copy_(torch.abs(key).max() / self.k_range) + self._v_scale.copy_(torch.abs(value).max() / self.v_range) + self._k_scale_float = self._k_scale.item() + self._v_scale_float = self._v_scale.item() + # We only calculate the scales once + self.calculate_kv_scales = False + + def extra_repr(self) -> str: + s = f"head_size={self.impl.head_size}" # type: ignore + s += f", num_heads={self.impl.num_heads}" # type: ignore + s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore + s += f", scale={self.impl.scale}" # type: ignore + s += f", backend={self.impl.__class__.__name__}" + return s + + def process_weights_after_loading(self, act_dtype: torch.dtype): + if hasattr(self.impl, "process_weights_after_loading"): + self.impl.process_weights_after_loading(act_dtype) + + +class MultiHeadAttention(nn.Module): + """Multi-headed attention without any cache, used for ViT.""" + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + ): + super().__init__() + self.num_heads = num_heads + self.head_size = head_size + self.scale = scale + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0, \ + f"num_heads ({self.num_heads}) is not " \ + f"divisible by num_kv_heads ({self.num_kv_heads})" + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + dtype = torch.get_default_dtype() + attn_backend = get_attn_backend(head_size, + dtype, + kv_cache_dtype=None, + block_size=64 if envs.VLLM_USE_FLASH_ATTN_PA or envs.VLLM_USE_FLASH_MLA else 16, + is_attention_free=False) + backend = backend_name_to_enum(attn_backend.get_name()) + if current_platform.is_rocm(): + # currently, only torch_sdpa is supported on rocm + self.attn_backend = _Backend.TORCH_SDPA + else: + if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1, + _Backend.FLEX_ATTENTION): + backend = _Backend.XFORMERS + + self.attn_backend = backend if backend in { + _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1 + } else _Backend.TORCH_SDPA + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + ) -> torch.Tensor: + """Input shape: batch_size x seq_len x hidden_size""" + # TODO(Isotr0py): Use existing backend implementations and support FA3 + bsz, q_len, _ = query.size() + kv_len = key.size(1) + + query = query.view(bsz, q_len, self.num_heads, self.head_size) + key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size) + value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size) + + if (num_repeat := self.num_queries_per_kv) > 1: + # Handle MQA and GQA + key = torch.repeat_interleave(key, num_repeat, dim=2) + value = torch.repeat_interleave(value, num_repeat, dim=2) + + if self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops + + out = xops.memory_efficient_attention_forward(query, + key, + value, + scale=self.scale) + elif self.attn_backend == _Backend.TORCH_SDPA: + query, key, value = (x.transpose(1, 2) + for x in (query, key, value)) + out = F.scaled_dot_product_attention(query, + key, + value, + scale=self.scale) + out = out.transpose(1, 2) + elif self.attn_backend == _Backend.PALLAS_VLLM_V1: + query, key, value = (x.transpose(1, 2) + for x in (query, key, value)) + from torch_xla.experimental.custom_kernel import flash_attention + out = flash_attention(query, key, value, sm_scale=self.scale) + out = out.transpose(1, 2) + + return out.reshape(bsz, q_len, -1) + + +def wait_for_kv_layer_from_connector(layer_name: str): + if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + return + + connector = get_kv_transfer_group() + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + assert isinstance(attn_metadata, dict) + connector.wait_for_layer_load(layer_name) + + +def maybe_save_kv_layer_to_connector( + layer_name: str, + kv_cache_layer: List[torch.Tensor], +): + if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + return + + connector = get_kv_transfer_group() + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + assert isinstance(attn_metadata, dict) + connector.save_kv_layer(layer_name, kv_cache_layer, + attn_metadata[layer_name]) + + +def unified_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + wait_for_kv_layer_from_connector(layer_name) + + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + output = self.impl.forward(self, query, key, value, kv_cache, + attn_metadata) + + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + return output + + +def unified_attention_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + return torch.empty_like(query).contiguous() + + +direct_register_custom_op( + op_name="unified_attention", + op_func=unified_attention, + mutates_args=[], + fake_impl=unified_attention_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def unified_attention_with_output( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, + output_scale: Optional[torch.Tensor] = None, +) -> None: + wait_for_kv_layer_from_connector(layer_name) + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] + self.impl.forward(self, + query, + key, + value, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale) + + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + + +def unified_attention_with_output_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + layer_name: str, + output_scale: Optional[torch.Tensor] = None, +) -> None: + return + + +direct_register_custom_op( + op_name="unified_attention_with_output", + op_func=unified_attention_with_output, + mutates_args=["output"], + fake_impl=unified_attention_with_output_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/attention/ops/__init__.py b/vllm/attention/ops/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/attention/ops/blocksparse_attention/__init__.py b/vllm/attention/ops/blocksparse_attention/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py new file mode 100644 index 0000000..05fa9d1 --- /dev/null +++ b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py @@ -0,0 +1,433 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.triton_utils import tl, triton + + +def blocksparse_flash_attn_varlen_fwd( + q, + k, + v, # (#tokens, n_heads, head_size) + cu_seqlens_k, + cu_seqlens_q, + sm_scale, + sparse_layout, + *, + block_size=64, + q_block_size=None, + max_seqlen=None): + # split q to blocks + + assert isinstance(sparse_layout, (list, tuple)) + + _, n_heads, head_size = q.shape + batch_size = cu_seqlens_k.size(0) - 1 + q_block_size = q_block_size or block_size + + assert q.dim() == k.dim() == v.dim() == 3 + assert q.size(1) % k.size(1) == 0 + assert q.size(2) == k.size(2) + # TODO(linxihui): allow k, v to have different head_size + assert k.shape == v.shape + assert cu_seqlens_k.dim() == 1 + + q_k_ratio = q.size(1) // k.size(1) + + if cu_seqlens_q is None: + if q.size(0) == batch_size: # decoding only + cu_seqlens_q = torch.arange( + 0, + batch_size + 1, + dtype=cu_seqlens_k.dtype, + device=cu_seqlens_k.device, + ) + elif q.size(0) == k.size(0): + cu_seqlens_q = cu_seqlens_k + else: + raise ValueError("cu_seqlens_q must be specified\ + if it mix of prefilling and decoding.") + else: + assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0) + + # switch to use cpu to avoid too many kernel launches when iterated over + q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu() + k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu() + + assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), ( + "length of q should either be 1 (decoding) or same as k (prefilling).") + + if max_seqlen: + assert k_lens.max() <= max_seqlen + + n_blocks = (q_lens + q_block_size - 1) // q_block_size + + q_batch_ids = torch.tensor( + [i for i, n in enumerate(n_blocks) for _ in range(n)], + dtype=cu_seqlens_q.dtype, + device=cu_seqlens_q.device, + ) + q_start_sids = torch.tensor( + [i * q_block_size for n in n_blocks for i in range(n)], + dtype=cu_seqlens_q.dtype, + device=cu_seqlens_q.device, + ) + + out = q.new_empty(q.shape) + cu_seqlens_q = cu_seqlens_q.contiguous() + cu_seqlens_k = cu_seqlens_k.contiguous() + + layout_crow_indices, layout_col_indices = sparse_layout + block_d = triton.next_power_of_2(head_size) + + decoding_only = (q_lens == 1).all().item() + grid = (len(q_start_sids), n_heads, 1) + + _fwd_kernel_batch_inference[grid]( + q, + k, + v, + out, + sm_scale, + cu_seqlens_q[:-1], + cu_seqlens_q[1:], + cu_seqlens_k[:-1], + cu_seqlens_k[1:], + q_batch_ids, + q_start_sids, + 0, + *q.stride(), + 0, + *k.stride(), + 0, + *v.stride(), + 0, + *out.stride(), + layout_crow_indices, + layout_col_indices, + *layout_crow_indices.stride(), + *layout_col_indices.stride(), + q_k_ratio, + HAS_BATCH_DIM=False, + D_HEAD=head_size, + BLOCK_M=q_block_size, + BLOCK_N=block_size, + BLOCK_D=block_d, + BLOCK_M_LOADING=(16 if decoding_only else + q_block_size), # smaller for decoding + EVEN_D=block_d == head_size, + num_warps=1 if decoding_only else 4, + num_stages=3) + + return out + + +@triton.jit +def _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_col_idx, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + LAST_K_BLOCK: tl.constexpr, + BLOCK_M_LOADING: tl.constexpr, + BLOCK_N: tl.constexpr, + D_HEAD: tl.constexpr, + EVEN_D: tl.constexpr, + M_LT_N: tl.constexpr, +): + k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h + + k_block_col_idx * layout_col_stride_m).to(tl.int32) + start_n = k_block_id * BLOCK_N + if LAST_K_BLOCK: + if EVEN_D: + k = tl.load( + k_ptrs + start_n * stride_kt, + mask=offs_n[None, :] + start_n < k_seqlen, + other=0.0, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kt, + mask=(offs_n[None, :] + start_n < k_seqlen) & + (offs_d[:, None] < D_HEAD), + other=0.0, + ) + else: + if EVEN_D: + k = tl.load(k_ptrs + start_n * stride_kt) + else: + k = tl.load(k_ptrs + start_n * stride_kt, + mask=offs_d[:, None] < D_HEAD, + other=0.0) + + qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N + if LAST_K_BLOCK | M_LT_N: + qk += tl.where( + offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), + 0, + float("-inf"), + ) + + # flash-attn2 + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + # update m_i + m_i = m_ij + l_i = l_i * alpha + l_ij + + p = p.to(Q.dtype.element_ty) + # update acc + if LAST_K_BLOCK: + if EVEN_D: + v = tl.load( + v_ptrs + start_n * stride_vt, + mask=offs_n[:, None] + start_n < k_seqlen, + other=0.0, + ) + else: + v = tl.load( + v_ptrs + start_n * stride_vt, + mask=(offs_n[:, None] + start_n < k_seqlen) & + (offs_d[None, :] < D_HEAD), + other=0.0, + ) + else: + if EVEN_D: + v = tl.load(v_ptrs + start_n * stride_vt) + else: + v = tl.load(v_ptrs + start_n * stride_vt, + mask=offs_d[None, :] < D_HEAD, + other=0.0) + + acc += tl.dot(p, v) + + return acc, l_i, m_i + + +@triton.heuristics({ + "M_LT_N": + lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"], +}) +@triton.jit +def _fwd_kernel_batch_inference( + Q, + K, + V, + Out, + sm_scale, + q_batch_starts, + q_batch_ends, + k_batch_starts, + k_batch_ends, + q_batch_ids, + q_start_sids, + stride_qb, + stride_qt, + stride_qh, + stride_qd, + stride_kb, + stride_kt, + stride_kh, + stride_kd, + stride_vb, + stride_vt, + stride_vh, + stride_vd, + stride_ob, + stride_ot, + stride_oh, + stride_od, + layout_crow_ptr, + layout_col_ptr, + layout_crow_stride_h, + layout_crow_stride_m, + layout_col_stride_h, + layout_col_stride_m, + q_k_ratio, + HAS_BATCH_DIM: tl.constexpr, + D_HEAD: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_M_LOADING: tl.constexpr, + EVEN_D: tl.constexpr, + M_LT_N: tl.constexpr, +): + """ + NOTATION: + pid: position id + sid: storage id + sbid: storage block id + pbid: position block id + offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col) + + TODO(linxihui): + Optimize grouped-attn + """ + off_zm = tl.program_id(0) + off_h = tl.program_id(1) + + off_h_for_kv = off_h // q_k_ratio + + if HAS_BATCH_DIM: + off_z = tl.program_id(2) + Q += off_z * stride_qb + K += off_z * stride_kb + V += off_z * stride_vb + Out += off_z * stride_ob + start_m = off_zm + q_start_sid = start_m * BLOCK_M # always 0 for decoding + else: + off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1] + q_start_sid = tl.load(q_start_sids + off_zm) + start_m = q_start_sid // BLOCK_M # q_sbid + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32) + q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start + k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32) + k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start + past_len = k_seqlen - q_seqlen + + Q += q_cu_start * stride_qt + off_h * stride_qh + K += k_cu_start * stride_kt + off_h_for_kv * stride_kh + V += k_cu_start * stride_vt + off_h_for_kv * stride_vh + Out += q_cu_start * stride_ot + off_h * stride_oh + + q_pbid = (past_len + q_start_sid) // BLOCK_M + + if EVEN_D: + q = tl.load( + Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, + mask=offs_m[:, None] < q_seqlen, + other=0.0, + ) + else: + q = tl.load( + Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, + mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), + other=0.0, + ) + + sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h + + q_pbid * layout_crow_stride_m) + + # TODO(linxihui): load at once, with any Triton version + # that supports `tl.split`, e.g., Triton 3.0 + k_block_start = tl.load(sparse_crow_ptr).to(tl.int32) + k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32) + + m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) + acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32) + + k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd + v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd + + sm_scale *= ( + 1.44269504 # 1/log2 as we use base2 for exponential and logarithm + ) + + for k_block_col_idx in range(k_block_start, k_block_end - 1): + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_col_idx, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + False, + BLOCK_M_LOADING, + BLOCK_N, + D_HEAD, + EVEN_D, + M_LT_N, + ) + + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_end - 1, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + True, + BLOCK_M_LOADING, + BLOCK_N, + D_HEAD, + EVEN_D, + M_LT_N, + ) + + # flash-attn 2 + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + + # write output + if EVEN_D: + tl.store( + Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, + acc, + mask=offs_m[:, None] < q_seqlen, + ) + else: + tl.store( + Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, + acc, + mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), + ) diff --git a/vllm/attention/ops/blocksparse_attention/interface.py b/vllm/attention/ops/blocksparse_attention/interface.py new file mode 100644 index 0000000..c6f6cc2 --- /dev/null +++ b/vllm/attention/ops/blocksparse_attention/interface.py @@ -0,0 +1,239 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math + +import torch + +from vllm.platforms import current_platform + +from .utils import (dense_to_crow_col, get_head_sliding_step, + get_sparse_attn_mask) + +IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80) + +if IS_COMPUTE_8_OR_ABOVE: + from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd + + +class LocalStridedBlockSparseAttn(torch.nn.Module): + + def __init__( + self, + n_heads, + max_seqlen, + local_blocks, + vert_stride, + block_size, + device=None, + dtype=None, + homo_head=False, + active_head_range=None, + q_block_size=None, + use_spda=None, + ): + super().__init__() + if use_spda is None: + use_spda = current_platform.is_rocm() or \ + current_platform.is_cpu() or not \ + IS_COMPUTE_8_OR_ABOVE + device = device or (torch.cuda.current_device() + if current_platform.is_cuda_alike() else "cpu") + device = torch.device(device) + # NOTE: vllm CPU backend support BF16 instead of FP16. + dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE + or device.type == "cpu" else torch.half) + + self.n_heads = n_heads + self.max_seqlen = max_seqlen + self.local_blocks = local_blocks + self.vert_stride = vert_stride + self.use_spda = use_spda + self.dtype = dtype + self.device = device + self.block_size = block_size + self.q_block_size = q_block_size + self.homo_head = homo_head + self.active_head_range = active_head_range + self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride, + homo_head) + + sparse_layout, sparse_pattern, self.dense_attn_mask = ( + self.get_attn_pattern(dtype, device)) + + if q_block_size is not None and q_block_size != block_size: + if q_block_size > block_size: + assert q_block_size % block_size == 0 + blocks_to_merge = q_block_size // block_size + shape = sparse_pattern.shape + sparse_pattern = sparse_pattern.view(shape[0], -1, + blocks_to_merge, + shape[-1]) + sparse_pattern = sparse_pattern.sum(2) + sparse_layout = dense_to_crow_col(sparse_pattern) + else: + raise ValueError( + "Does not support smaller q_block_size. It will be slower." + ) + + self.sparse_layout = sparse_layout + + def get_attn_pattern(self, dtype, device): + sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask( + self.n_heads, + self.max_seqlen, + self.max_seqlen, + dtype, + device, + block_size=self.block_size, + local_blocks=self.local_blocks, + vert_stride=self.vert_stride, + homo_head=self.homo_head, + return_dense=self.use_spda, + dense_mask_type="bias", + ) + if (not self.homo_head) and (self.active_head_range is not None): + assert isinstance(self.active_head_range, tuple) + assert (len(self.active_head_range) == 2) + h_start, h_end = self.active_head_range + sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout) + if self.use_spda: + dense_attn_mask = dense_attn_mask[h_start:h_end] + return sparse_layout, sparse_pattern, dense_attn_mask + + def varlen_attn(self, + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=None, + sm_scale=None): + """ + q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). + Support grouped attention, with `q[:, i*r:(i*r + r)]` + is correspondent to `k[:, i]`, where `r` is the q/k ratio. + cu_seqlens_k: shape=(batch_size + 1,), + indicating segment of samples, + e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i + cu_seqlens_q: shape=(batch_size + 1, ). + Default None: same as cu_seqlens_k for prefilling or + [0, 1, .., batch_size] for decoding. + The only case you need to specify is when q is a mix of + prefilling and decoding. + sm_scale: softmax scale, default to 1/sqrt(head_size). + + return: tensor of shape as q. + """ + assert ( + IS_COMPUTE_8_OR_ABOVE + ), "Requires compute capability of 8 or above (Ampere or newer) to use \ + Triton kernel." + + sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1)) + + return blocksparse_flash_attn_varlen_fwd( + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q, + sm_scale, + self.sparse_layout, + block_size=self.block_size, + q_block_size=self.q_block_size, + max_seqlen=self.max_seqlen, + ) + + @staticmethod + def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1): + """ + :param x: (total_tokens, n_heads, head_size) + :return: (batch, n_heads, length, head_size) + """ + x_padded = x.new_empty( + len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2)) + cu_seqlens = cu_seqlens.cpu() + for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])): + x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0, + 1).unsqueeze(1)) + return x_padded.flatten(1, 2) + + @staticmethod + def transpose_and_unpad(x_padded, cu_seqlens): + """ + :param x_padded: (batch, n_heads, length, head_size) + :return: (total_tokens, n_heads, head_size) + """ + cu_seqlens = cu_seqlens.cpu() + total_n_tokens = cu_seqlens[-1] + x = x_padded.new_empty(total_n_tokens, x_padded.size(1), + x_padded.size(3)) + for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])): + x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1)) + return x + + def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): + """For CPU, V100 or other older GPUs. + NOTE: torch SPDA supports nested tensor, + but seems extremely slow. Choose to pad instead. + """ + assert (cu_seqlens_q is None or + (cu_seqlens_q + == cu_seqlens_k).all()), "Can only handle prompt with SPDA." + assert q.size(0) == k.size(0), "can only handle prompt with SPDA." + + assert q.size(1) % k.size(1) == 0 + q_k_ratio = q.size(1) // k.size(1) + sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1)) + cu_seqlens = cu_seqlens_k.cpu() + maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + + if (self.dense_attn_mask.dtype != q.dtype + or self.dense_attn_mask.device != q.device): + _, _, self.dense_attn_mask = self.get_attn_pattern( + q.dtype, q.device) + attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen] + + q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1) + k2, v2 = (self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio) + for x in [k, v]) + spda_output = torch.nn.functional.scaled_dot_product_attention( + q2, k2, v2, attn_mask=attn_mask, scale=sm_scale) + return self.transpose_and_unpad(spda_output, cu_seqlens) + + def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): + """Dispatch to `varlen_attn` (Ampere or newer) or + `self.spda`(cpu, Volta, Turing or older)based on + the type of device used and cuda compute capability. + + q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). + Support grouped attention, with `q[:, i*r:(i*r + r)]` + is correspondent to `k[:, i]`, where `r` is the q/k ratio. + cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples, + e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i + cu_seqlens_q: shape=(batch_size + 1, ). + Default None: same as cu_seqlens_k for prefilling or + [0, 1, .., batch_size] for decoding. + The only case you need to specify + is when q is a mix of prefilling + and decoding. + sm_scale: softmax scale, default to 1/sqrt(head_size). + + return: tensor of shape as q. + """ + assert k.dim() == 3 + if self.use_spda: + return self.spda( + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=cu_seqlens_q, + sm_scale=sm_scale, + ) + return self.varlen_attn(q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=cu_seqlens_q, + sm_scale=sm_scale) diff --git a/vllm/attention/ops/blocksparse_attention/utils.py b/vllm/attention/ops/blocksparse_attention/utils.py new file mode 100644 index 0000000..445720c --- /dev/null +++ b/vllm/attention/ops/blocksparse_attention/utils.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Helper functions for 3D sparse pattern +# These function are not optimized and very inefficient. +# Avoid calling them too frequent or use a cache mechanism. + +from functools import lru_cache + +import numpy as np +import torch + +from vllm.triton_utils import triton + + +class csr_matrix: + """Simple implementation of CSR matrix conversion without scipy. + This replaced scipy.sparse.csr_matrix() previously used.""" + + def __init__(self, input_array): + if not isinstance(input_array, np.ndarray): + raise ValueError("Input must be a NumPy array") + + self.shape = input_array.shape + rows, cols = self.shape + data = [] + indices = [] + indptr = [0] + + for i in range(rows): + for j in range(cols): + if input_array[i, j]: + data.append(input_array[i, j]) + indices.append(j) + indptr.append(len(indices)) + + self.data = np.array(data) + self.indices = np.array(indices) + self.indptr = np.array(indptr) + + +def dense_to_crow_col(x: torch.Tensor): + """Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing. + NOTE: col_indices padded -1 + """ + device = x.device + pad = -1 + dim = x.dim() + assert x.dim() in (2, 3) + if x.dim() == 2: + x = x[None] + x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x] + crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x]) + cols = [torch.from_numpy(xi.indices) for xi in x] + max_cols = max(len(xi) for xi in cols) + cols = [ + torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) + for xi in cols + ] + cols = torch.vstack(cols) + if dim == 2: + crows = crows[0] + cols = cols[0] + return crows.to(device), cols.to(device) + + +def crow_col_to_dense(crows: torch.Tensor, + cols: torch.Tensor, + dtype: torch.dtype = torch.float16): + dim = crows.dim() + if dim == 1: + crows = crows[None] + cols = cols[None] + device = crows.device + crows, cols = crows.cpu(), cols.cpu() # faster in cpu + shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1) + x = torch.zeros(shape, dtype=dtype) + for i in range(shape[0]): + for j in range(shape[1]): + x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1 + if dim == 1: + x = x[0] + return x.to(device) + + +def dense_to_ccol_row(x: torch.Tensor): + """Similar, but to CSC format""" + x = x.transpose(-2, -1) + return dense_to_crow_col(x) + + +def ccol_row_to_dense(ccol: torch.Tensor, + rows: torch.Tensor, + dtype: torch.dtype = torch.float16): + return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous() + + +def _get_sparse_attn_mask_homo_head( + q_len: int, + max_seqlen: int, + dtype: torch.dtype, + device: torch.device, + block_size: int = 128, + local_blocks: int = 4, + vert_stride: int = 4, + return_dense: bool = False, +): + """ + :return: a tuple of 3: + - tuple of crow_indices, col_indices representation + of CSR format. + - block dense mask + - all token dense mask (be aware that it can be + OOM if it is too big) if `return_dense==True`, + otherwise, None + """ + with torch.no_grad(): + num_blocks = triton.cdiv(max_seqlen, block_size) + q_pos = torch.arange(num_blocks)[:, None] + k_pos = torch.arange(num_blocks)[None] + mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0 + block_mask_dense = (((q_pos >= k_pos) + & ((q_pos - k_pos < local_blocks) + | mask_vert_strided)).to(device).to(dtype)) + num_blocks_q = triton.cdiv(q_len, block_size) + block_mask_dense_output = (dense_to_crow_col( + block_mask_dense[-num_blocks_q:].contiguous())) + if return_dense: + mask_dense = torch.kron( + block_mask_dense, + block_mask_dense.new_ones((block_size, block_size)), + ) + causal_mask = torch.tril(torch.ones( + max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:] + mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask + return ( + block_mask_dense_output, + block_mask_dense, + mask_dense, + ) + else: + return ( + block_mask_dense_output, + block_mask_dense, + None, + ) + + +def binary_mask_to_bias(mask_dense: torch.Tensor): + mask_dense = 1 - mask_dense + mask_dense.masked_fill_(mask_dense.bool(), -torch.inf) + return mask_dense + + +def get_head_sliding_step(n_heads: int, + vert_stride: int, + homo_head: bool = False): + if homo_head: + return 0 + return max(1, int(vert_stride / n_heads)) + + +@lru_cache +def get_sparse_attn_mask( + n_heads: int, + q_len: int, + max_seqlen: int, + dtype: torch.dtype, + device: torch.device, + block_size: int = 64, + local_blocks: int = 4, + vert_stride: int = 4, + homo_head: bool = True, + return_dense: bool = False, + dense_mask_type: str = "binary", +): + """ + :param dense_mask_type: "binary" (0 for skip token, 1 for others) + or "bias" (-inf for skip token, 0 or others) + :return: a tuple of 3: + - tuple of crow_indices, col_indices representation + of CSR format. + - block dense mask + - all token dense mask (be aware that it can be OOM if it + is too big) if `return_dense==True`, otherwise, None + """ + assert dense_mask_type in ("binary", "bias") + if homo_head: + with torch.no_grad(): + (crow, col), block_mask_dense, mask_dense = ( + _get_sparse_attn_mask_homo_head( + q_len, + max_seqlen, + dtype, + device, + block_size, + local_blocks, + vert_stride, + return_dense, + )) + crow = crow[None].expand(n_heads, crow.shape[0]) + col = col[None].expand(n_heads, col.shape[0]) + if return_dense: + mask_dense = mask_dense[None].expand(n_heads, + *mask_dense.shape) + if dense_mask_type == "bias": + mask_dense = binary_mask_to_bias(mask_dense) + return (crow, col), block_mask_dense, mask_dense + + with torch.no_grad(): + num_blocks = triton.cdiv(max_seqlen, block_size) + q_pos = torch.arange(num_blocks)[None, :, None] + k_pos = torch.arange(num_blocks)[None, None] + head_sliding_step = get_head_sliding_step(n_heads, vert_stride) + mask_vert_strided = [ + (torch.arange(num_blocks) + h * head_sliding_step + 1) % + vert_stride == 0 for h in range(n_heads) + ] + mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1) + block_mask_dense = (((q_pos >= k_pos) + & ((q_pos - k_pos < local_blocks) + | mask_vert_strided)).to(device).to(dtype)) + num_blocks_q = triton.cdiv(q_len, block_size) + block_mask_dense_output = block_mask_dense[:, -num_blocks_q:] + if return_dense: + mask_dense = torch.kron( + block_mask_dense, + block_mask_dense.new_ones((block_size, block_size)), + ) + causal_mask = torch.tril(torch.ones( + max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:] + mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None] + if dense_mask_type == "bias": + mask_dense = binary_mask_to_bias(mask_dense) + + return ( + dense_to_crow_col(block_mask_dense_output), + block_mask_dense, + mask_dense, + ) + else: + return ( + dense_to_crow_col(block_mask_dense_output), + block_mask_dense, + None, + ) diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py new file mode 100644 index 0000000..4f83934 --- /dev/null +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -0,0 +1,368 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Authors: +# - Burkhard Ringlein +# - Jan van Lunteren +# - Chih-Chieh Yang +# - Thomas Parnell + +import torch + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.platforms.rocm import use_rocm_custom_paged_attention +from vllm.triton_utils import tl, triton + +from .prefix_prefill import context_attention_fwd + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def kernel_paged_attention_2d( + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + num_queries_per_kv_padded: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + x: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.int64, # int + stride_k_cache_4: tl.int64, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.int64, # int + filter_by_query_len: tl.constexpr, # bool + query_start_len_ptr, # [num_seqs+1] +): + seq_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + + if filter_by_query_len: + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + + 1) + cur_batch_query_len = cur_batch_in_all_stop_index \ + - cur_batch_in_all_start_index + if cur_batch_query_len > 1: + return + else: + cur_batch_in_all_start_index = seq_idx + + query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange( + 0, num_queries_per_kv_padded) + + query_offset = (cur_batch_in_all_start_index * query_stride_0 + + query_head_idx[:, None] * query_stride_1) + + head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv + head_mask = head_mask & (query_head_idx < num_query_heads) + + dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, + 0).to(tl.int1) + + # Q : (num_queries_per_kv, HEAD_SIZE,) + Q = tl.load( + query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :], + mask=dim_mask[None, :] & head_mask[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32) + L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32) + acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], + dtype=tl.float32) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx, + mask=head_mask, + other=0.0) + + num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) + + # iterate through tiles + for j in range(0, num_blocks): + + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + + offs_n = tl.arange(0, BLOCK_SIZE) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + + v_offset = (physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_1 + + offs_d[None, :] * stride_v_cache_2 + + offs_n[:, None] * stride_v_cache_3) + + k_offset = (physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_1 + + (offs_d[:, None] // x) * stride_k_cache_2 + + offs_n[None, :] * stride_k_cache_3 + + (offs_d[:, None] % x) * stride_k_cache_4) + + # K : (HEAD_SIZE, BLOCK_SIZE) + K_load = tl.load(key_cache_ptr + k_offset, + mask=dim_mask[:, None], + other=0.0) + + if K_load.dtype.is_fp8(): + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + # V : (BLOCK_SIZE, HEAD_SIZE) + V_load = tl.load(value_cache_ptr + v_offset, + mask=dim_mask[None, :], + other=0.0) + + if V_load.dtype.is_fp8(): + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + boundary = tl.full([BLOCK_SIZE], seq_len, dtype=tl.int32) + seq_mask = seq_offset[None, :] < boundary + + # S : (num_queries_per_kv, BLOCK_SIZE,) + S = tl.where(head_mask[:, None] & seq_mask, 0.0, + float("-inf")).to(tl.float32) + S += scale * tl.dot(Q, K) + + context_len = seq_len - 1 + + if SLIDING_WINDOW > 0: + S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S, + -10000) + + if USE_ALIBI_SLOPES: + S += alibi_slope[:, None] * (seq_offset - context_len) + + # compute running maximum + # m_j : (num_queries_per_kv,) + m_j = tl.maximum(M, tl.max(S, axis=1)) + + # P : (num_queries_per_kv, BLOCK_SIZE,) + P = tl.exp(S - m_j[:, None]) + + # l_j : (num_queries_per_kv,) + l_j = tl.sum(P, axis=1) + + # alpha : (num_queries_per_kv, ) + alpha = tl.exp(M - m_j) + + # acc : (num_queries_per_kv, BLOCK_SIZE,) + acc = acc * alpha[:, None] + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (num_queries_per_kv, BLOCK_SIZE,) + acc += tl.dot(P.to(V.dtype), V) + + # epilogue + acc = acc / L[:, None] + + output_offset = (cur_batch_in_all_start_index * output_stride_0 + + query_head_idx * output_stride_1) + + tl.store( + output_ptr + output_offset[:, None] + + tl.arange(0, HEAD_SIZE_PADDED)[None, :], + acc, + mask=dim_mask[None, :] & head_mask[:, None], + ) + + +def chunked_prefill_paged_decode( + query, + key, + value, + output, + kv_cache_dtype, + key_cache, + value_cache, + block_table, + query_start_loc, + seq_lens, + max_seq_len, + max_query_len, + k_scale, + v_scale, + alibi_slopes=None, + sliding_window=None, + sm_scale=None, +): + + if sm_scale is None: + sm_scale = 1.0 / (query.shape[1]**0.5) + + use_alibi_slopes = alibi_slopes is not None + + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + + if max_query_len > 1: + context_attention_fwd( + q=query, + k=key, + v=value, + o=output, + kv_cache_dtype=kv_cache_dtype, + k_cache=key_cache, + v_cache=value_cache, + b_loc=block_table, + b_start_loc=query_start_loc, + b_seq_len=seq_lens, + max_seq_len=max_seq_len, + max_input_len=max_query_len, + k_scale=k_scale, + v_scale=v_scale, + alibi_slopes=alibi_slopes, + sliding_window=sliding_window, + sm_scale=sm_scale, + skip_decode=True, + ) + + block_size = value_cache.shape[3] + num_seqs = len(seq_lens) + num_query_heads = query.shape[1] + num_kv_heads = key.shape[1] + num_queries_per_kv = query.shape[1] // key.shape[1] + head_size = query.shape[2] + + # Conversion of FP8 Tensor from uint8 storage to + # appropriate torch.dtype for interpretation by Triton + if "fp8" in kv_cache_dtype: + assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()] + assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()] + + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + target_dtype = current_platform.fp8_dtype() + elif kv_cache_dtype == "fp8_e5m2": + target_dtype = torch.float8_e5m2 + else: + raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) + + key_cache = key_cache.view(target_dtype) + value_cache = value_cache.view(target_dtype) + + num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), + 16) + + use_custom = use_rocm_custom_paged_attention(query.dtype, head_size, + block_size, + num_queries_per_kv, + max_seq_len, sliding_window, + kv_cache_dtype, alibi_slopes) + if use_custom: + _PARTITION_SIZE_ROCM = 256 + max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) // + _PARTITION_SIZE_ROCM) + assert _PARTITION_SIZE_ROCM % block_size == 0 + total_num_seq = block_table.shape[0] + tmp_output = torch.empty( + size=(total_num_seq, num_query_heads, max_num_partitions, + head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(total_num_seq, num_query_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + + ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale=sm_scale, + block_tables=block_table, + seq_lens=seq_lens, + query_start_loc=query_start_loc, + block_size=block_size, + max_seq_len=max_seq_len, + alibi_slopes=alibi_slopes, + kv_cache_dtype=kv_cache_dtype, + k_scale=k_scale, + v_scale=v_scale, + ) + else: + kernel_paged_attention_2d[( + num_seqs, + num_kv_heads, + )]( + output_ptr=output, + query_ptr=query, + key_cache_ptr=key_cache, + value_cache_ptr=value_cache, + block_tables_ptr=block_table, + seq_lens_ptr=seq_lens, + alibi_slopes_ptr=alibi_slopes, + scale=sm_scale, + k_scale=k_scale, + v_scale=v_scale, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + num_queries_per_kv_padded=num_queries_per_kv_padded, + block_table_stride=block_table.stride(0), + query_stride_0=query.stride(0), + query_stride_1=query.stride(1), + output_stride_0=output.stride(0), + output_stride_1=output.stride(1), + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + SLIDING_WINDOW=sliding_window, + x=key_cache.shape[4], + stride_k_cache_0=key_cache.stride(0), + stride_k_cache_1=key_cache.stride(1), + stride_k_cache_2=key_cache.stride(2), + stride_k_cache_3=key_cache.stride(3), + stride_k_cache_4=key_cache.stride(4), + stride_v_cache_0=value_cache.stride(0), + stride_v_cache_1=value_cache.stride(1), + stride_v_cache_2=value_cache.stride(2), + stride_v_cache_3=value_cache.stride(3), + filter_by_query_len=True, + query_start_len_ptr=query_start_loc, + ) diff --git a/vllm/attention/ops/flash_attn_triton_mqa_gqa.py b/vllm/attention/ops/flash_attn_triton_mqa_gqa.py new file mode 100644 index 0000000..70a3edc --- /dev/null +++ b/vllm/attention/ops/flash_attn_triton_mqa_gqa.py @@ -0,0 +1,1308 @@ +#!/usr/bin/env python +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team, AMD ML Frameworks Triton team + +Features supported: + +1) Fwd with causal masking +2) Any sequence lengths without padding (currently fwd kernel only) +3) Support for different sequence lengths for q and k +4) Nested tensor API currently does not support dropout or bias. + +Not currently supported: + +1) Non power of two head dims + +""" + +import argparse +import pytest +import random +import sys +import torch + +import triton +import triton.language as tl + +torch_dtype:tl.constexpr = torch.float16 + +TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz') +if TORCH_HAS_FP8E5: + torch_dtype:tl.constexpr = torch.float8_e5m2fnuz + +class MetaData(): + cu_seqlens_q = None + cu_seqlens_k = None + max_seqlens_q = 0 + max_seqlens_k = 0 + bias = None + alibi_slopes = None + causal = False + num_contexts = 0 + varlen = False + dropout_p, return_encoded_softmax = 0.0, False + + def __init__(self, sm_scale=1.0, causal=False, dropout_p=0.0, return_encoded_softmax=False): + self.sm_scale = sm_scale + self.causal = causal + self.dropout_p = dropout_p + self.return_encoded_softmax = return_encoded_softmax + + def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): + self.varlen = True + self.cu_seqlens_q = cu_seqlens_q + self.cu_seqlens_k = cu_seqlens_k + # Without "varlen", there should still be one sequence. + assert len(cu_seqlens_q) >= 2 + assert len(cu_seqlens_q) == len(cu_seqlens_k) + self.num_contexts = len(cu_seqlens_q) - 1 + for i in range (0, self.num_contexts): + self.max_seqlens_q = max(cu_seqlens_q[i+1].item() - cu_seqlens_q[i].item(), self.max_seqlens_q) + self.max_seqlens_k = max(cu_seqlens_k[i+1].item() - cu_seqlens_k[i].item(), self.max_seqlens_k) + + def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k): + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.shape[0] == 1 + assert bias.shape[2:] == (seqlen_q, seqlen_k) + self.bias = bias + + def need_alibi(self, alibi_slopes, batch, nheads): + assert alibi_slopes.is_cuda + assert alibi_slopes.dim() == 2 + assert alibi_slopes.shape[0] == batch + assert alibi_slopes.shape[1] == nheads + self.alibi_slopes = alibi_slopes + + def need_causal(self): + self.causal = True + + def need_dropout(dropout_p, return_encoded_softmax): + self.dropout_p = dropout_p + self.return_encoded_softmax = return_encoded_softmax + + def check_args(self, q, k, v, o): + assert q.dim() == k.dim() and q.dim() == v.dim() + if self.varlen: + assert q.dim() == 3 + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + assert self.cu_seqlens_q is not None + assert self.cu_seqlens_k is not None + assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) + # TODO: Remove once bias is supported with varlen + assert self.bias == None + # TODO:Remove once dropout is supported with varlen + assert self.dropout_p == 0.0 + assert not self.return_encoded_softmax + else: + assert q.dim() == 4 + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 + assert self.cu_seqlens_q is None and self.cu_seqlens_k is None + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert q.dtype == k.dtype and q.dtype == v.dtype + assert head_size <= 256 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + +@triton.jit +def cdiv_fn(x,y): + return (x + y - 1) // y + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + +@triton.jit +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + +@triton.jit +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) + rng_keep = rng_output > dropout_p + return rng_keep + +@triton.jit +def load_fn(block_ptr, first, second, pad): + if first and second: + tensor = tl.load(block_ptr, boundary_check=(0,1), padding_option=pad) + elif first: + tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad) + elif second: + tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad) + else: + tensor = tl.load(block_ptr) + return tensor + +@triton.jit +def print_gpu(prefix, val=None): + if (tl.program_id(0) == 0) and ((tl.program_id(1) == 0) and (tl.program_id(2) == 0)): + if val is not None: + tl.device_print(prefix, val) + else: + tl.device_print(prefix) + +@triton.jit +def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose = False): + # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix + # for casual mask we want something like this where (1 is kept and 0 is masked) + # seqlen_q = 2 and seqlen_k = 5 + # 1 1 1 1 0 + # 1 1 1 1 1 + # seqlen_q = 5 and seqlen_k = 2 + # 0 0 + # 0 0 + # 0 0 + # 1 0 + # 1 1 + # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal + # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False + # 1. offs_m[:,None] = [[0], + # [1], + # 2. offs_m[:,None] + seqlen_k = [[5], + # [6], + # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], + # [4], + # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], + # [4], [ 4, 3, 2, 1, 0]] + # 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], + # [ -4, -3, -2, -1, 0]], + relative_pos_block = offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + if transpose: + return alibi_block.T + else: + return alibi_block + +@triton.jit +def _attn_fwd_inner( + acc, l_i, m_i, q, + K_block_ptr, V_block_ptr, + start_m, + actual_seqlen_k, + actual_seqlen_q, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + alibi_slope, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr +): + # loop over k, v, and update accumulator + for start_n in range (block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + k = load_fn(K_block_ptr, PADDED_HEAD, MASK_STEPS and (n_extra_tokens != 0), "zero") + if PRE_LOAD_V: + v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. + # last step might get wasted but that is okay. check if this masking works For + # that case. + if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) + size_n = start_n + OFFS_N[None,:] + mask = size_n < boundary_m[:,None] + qk = tl.where(mask, qk, float("-inf")) + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- + qk += tl.dot(q, k) + if bias_ptr is not None: + bias = load_fn(bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero") + # While bias is added after multiplying qk with sm_scale, + # our optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. + qk += (bias * 1.44269504089) + + if alibi_slope is not None: + # Compute the global position of each token within the sequence + global_m_positions = start_m*BLOCK_M + tl.arange(0, BLOCK_M) + global_n_positions = start_n + tl.arange(0, BLOCK_N) + + alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, global_n_positions) + + qk += (alibi_block * 1.44269504089) # scale factor of log2(e) + + # softmax + m_ij = tl.maximum(m_i, tl.max(qk,1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N + keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) + if RETURN_ENCODED_SOFTMAX: + tl.store(encoded_softmax_block_ptr, tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty)) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store(encoded_softmax_block_ptr, p.to(encoded_softmax_block_ptr.type.element_ty)) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N)) + return acc, l_i, m_i + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 0, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 0, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 0, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'PRE_LOAD_V': True}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 0, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 0, 'PRE_LOAD_V': True}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 0, 'PRE_LOAD_V': False}, num_stages=2, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 16, 'waves_per_eu': 0, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 0, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 0, 'PRE_LOAD_V': True}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 0, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + # TODO: This config fails with head_size not pow2 with data mismatches. Check why. + # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + # triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 0, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + ], + key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], +# use_cuda_graph=True, +) +@triton.jit +def attn_fwd( + Q, K, V, bias, sm_scale, L, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + stride_bz, stride_bh, stride_bm, stride_bn, + stride_az, stride_ah, + cu_seqlens_q, cu_seqlens_k, + dropout_p, philox_seed, philox_offset_base, encoded_softmax, + alibi_slopes, + HQ: tl.constexpr, HK:tl.constexpr, + ACTUAL_BLOCK_DMODEL:tl.constexpr, + MAX_SEQLENS_Q:tl.constexpr, MAX_SEQLENS_K:tl.constexpr, + VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BIAS_TYPE: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + USE_ALIBI: tl.constexpr, + BATCH_SIZE: tl.constexpr, +): + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if (IS_CAUSAL): + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, + BLOCK_N + ) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is part of + # the blocks that are all 0. We exit early. + if n_blocks <= 0: + o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + # We still need to write 0s to the result + tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1)) + l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # We store inf to LSE, not -inf because in the bwd pass, we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. + l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + tl.store(l_ptrs, l) + # TODO: Should dropout and return encoded softmax be handled here too? + return + + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + if GROUP_SIZE != 1: + off_h_k = off_h_q // GROUP_SIZE + else: + off_h_k = off_h_q + + need_padding = False + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + need_padding = True + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + need_padding = True + n_extra_tokens = seqlen_k % BLOCK_N + PADDED_HEAD:tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + + # Compute pointers for all the tensors used in this kernel. + q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0) + ) + if BIAS_TYPE != 0: + b_offset = off_h_q * stride_bh # Note: this might get large enough to overflow on some configs + bias_ptr = tl.make_block_ptr( + base=bias + b_offset, + shape=(seqlen_q, seqlen_k), + strides=(stride_bm, stride_bn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + bias_ptr = None + + if USE_ALIBI: + a_offset = off_z * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(alibi_slopes + a_offset) + else: + alibi_slope = None + + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. In + # this case, we return an invalid pointer so indicate the mask is not valid. + # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.make_block_ptr( + base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, + shape=(seqlen_q, seqlen_k), + strides=(seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0) + ) + else: + encoded_softmax_block_ptr = 0 + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. + qk_scale = sm_scale * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero") + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. + # In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its actual + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i, q, K_block_ptr, V_block_ptr, + start_m, seqlen_k, seqlen_q, + dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, block_max, 0, 0, 0, bias_ptr, alibi_slope, + # IS_CAUSAL, .... + False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD + ) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if (masked_blocks > 0): + if IS_CAUSAL: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) + else: + offs_n_causal = 0 + K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks*BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks*BLOCK_N, 0)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks*BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, n_full_blocks)) + acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i, q, K_block_ptr, V_block_ptr, + start_m, seqlen_k, seqlen_q, + dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, bias_ptr, alibi_slope, + IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD + ) + # epilogue + acc = acc / l_i[:, None] + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. + # This is only true for the last M block. For others, overflow_size will be -ve + overflow_size = end_m_idx - seqlen_q + if overflow_size > 0: + boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + # This is a > check because mask being 0 blocks the store. + l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + else: + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + # Need boundary check on this to make sure the padding from the + # Q and KV tensors in both dims are not part of what we store back. + # TODO: Do the boundary check optionally. + tl.store(O_block_ptr, acc, boundary_check=(0,1)) + +@triton.jit +def _attn_bwd_preprocess( + Out, DO, + Delta, + stride_oz, stride_oh, stride_om, stride_on, + stride_doz, stride_doh, stride_dom, stride_don, + seqlen_q, + head_dim, + BLOCK_M: tl.constexpr, + D_HEAD: tl.constexpr, +): + # off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + # off_n = tl.arange(0, D_HEAD) + off_m = tl.program_id(0) * BLOCK_M + off_h = tl.program_id(1) # head index + off_z = tl.program_id(2) # batch index + num_h = tl.num_programs(1) + o_offset = off_h * stride_oh + off_z * stride_oz + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, head_dim), + strides=(stride_om, stride_on), + offsets=(off_m, 0), + block_shape=(BLOCK_M, D_HEAD), + order=(1, 0) + ) + do_offset = off_h * stride_doh + off_z * stride_doz + DO_block_ptr = tl.make_block_ptr( + base=DO + do_offset, + shape=(seqlen_q, head_dim), + strides=(stride_dom, stride_don), + offsets=(off_m, 0), + block_shape=(BLOCK_M, D_HEAD), + order=(1, 0) + ) + # load + # o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + # do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + o = tl.load(O_block_ptr, boundary_check=(0,1), padding_option="zero").to(tl.float32) + do = tl.load(DO_block_ptr, boundary_check=(0,1), padding_option="zero").to(tl.float32) + # compute + delta = tl.sum(o * do, axis=1) + # write-back, shape (q.shape[0] * q.shape[1], q.shape[2]) + off_zh = off_z * num_h + off_h * 1 + # Check for OOB accesses + delta_ptrs = Delta + off_zh * seqlen_q + off_m + tl.arange(0, BLOCK_M) + overflow = off_m + BLOCK_M - seqlen_q + if overflow > 0: + boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow, dtype=tl.int32) + mask = boundary > tl.arange(0, BLOCK_M) + tl.store(delta_ptrs, delta, mask=mask) + else: + tl.store(delta_ptrs, delta) + +@triton.jit +def _bwd_kernel_dk_dv( + dk, dv, + Q, k, v, sm_scale, alibi_slope, + DO, + M, D, + # shared by Q/K/V/DO. + stride_tok, stride_d, + H, N_CTX, BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. + start_n, start_m, num_steps, + MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M1) + offs_n = start_n + tl.arange(0, BLOCK_N1) + offs_k = tl.arange(0, BLOCK_DMODEL) + QT_block_ptr = tl.make_block_ptr( + base=Q, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_d, stride_tok), + offsets=(0, start_m), + block_shape=(BLOCK_DMODEL, BLOCK_M1), + order=(0,1) + ) + DO_block_ptr = tl.make_block_ptr( + base=DO, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_m, 0), + block_shape=(BLOCK_M1, BLOCK_DMODEL), + order=(1,0) + ) + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + curr_m = start_m + step_m = BLOCK_M1 + for blk_idx in range(num_steps): + qT = tl.load(QT_block_ptr) + # Load m before computing qk to reduce pipeline stall. + offs_m = curr_m + tl.arange(0, BLOCK_M1) + m = tl.load(M + offs_m) + kqT = tl.dot(k, qT) + if alibi_slope is not None: + alibi_block = compute_alibi_block(alibi_slope, N_CTX, N_CTX, offs_m, offs_n, True) + kqT += alibi_block * 1.44269504089 + + pT = tl.math.exp2(kqT - m[None, :]) + # Autoregressive masking. + if MASK: + mask = (offs_m[None, :] >= offs_n[:, None]) + pT = tl.where(mask, pT, 0.0) + do = tl.load(DO_block_ptr) + # Compute dV. + ppT = pT + ppT = ppT.to(tl.float16) + dv += tl.dot(ppT, do) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)) + dsT = pT * (dpT - Di[None, :]) + dsT = dsT.to(tl.float16) + dk += tl.dot(dsT, tl.trans(qT)) + # Increment pointers. + curr_m += step_m + QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m)) + DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0)) + return dk, dv + +@triton.jit +def _bwd_kernel_dq(dq, q, K, V, + do, m, D, alibi_slope, + # shared by Q/K/V/DO. + stride_tok, stride_d, + H, N_CTX, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. + start_m, start_n, num_steps, + MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k = tl.arange(0, BLOCK_DMODEL) + KT_block_ptr = tl.make_block_ptr( + base=K, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_d, stride_tok), + offsets=(0, start_n), + block_shape=(BLOCK_DMODEL, BLOCK_N2), + order=(0, 1) + ) + VT_block_ptr = tl.make_block_ptr( + base=V, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_d, stride_tok), + offsets=(0, start_n), + block_shape=(BLOCK_DMODEL, BLOCK_N2), + order=(0, 1) + ) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + for blk_idx in range(num_steps): + kT = tl.load(KT_block_ptr) + qk = tl.dot(q, kT) + if alibi_slope is not None: + alibi_block = compute_alibi_block(alibi_slope, N_CTX, N_CTX, offs_m, offs_n) + qk += alibi_block * 1.44269504089 + + p = tl.math.exp2(qk - m) + # Autoregressive masking. + if MASK: + offs_n = curr_n + tl.arange(0, BLOCK_N2) + mask = (offs_m[:, None] >= offs_n[None, :]) + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + vT = tl.load(VT_block_ptr) + dp = tl.dot(do, vT).to(tl.float32) + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.float16) + # Compute dQ.0. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. + curr_n += step_n + KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n)) + VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n)) + return dq + +@triton.jit +def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, + DO, + DQ, DK, DV, + M, D, + # shared by Q/K/V/DO. + stride_z, stride_h, stride_tok, stride_d, + # H = 16, N_CTX = 1024 + H, N_CTX, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + USE_ALIBI: tl.constexpr): + LN2: tl.constexpr = 0.6931471824645996 # = ln(2) + + bhid = tl.program_id(2) + off_chz = (bhid * N_CTX).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + pid = tl.program_id(0) + + # offset pointers for batch/head + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DK += adj + DV += adj + M += off_chz + D += off_chz + + offs_k = tl.arange(0, BLOCK_DMODEL) + + start_n = pid * BLOCK_N1 + # This assignment is important. It is what allows us to pick the diagonal + # blocks. Later, when we want to do the lower triangular, we update start_m + # after the first dkdv call. + start_m = start_n + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + offs_n = start_n + tl.arange(0, BLOCK_N1) + + dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1, 0), + ) + + # load K and V: they stay in SRAM throughout the inner loop for dkdv. + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + + if USE_ALIBI: + a_offset = bhid + alibi_slope = tl.load(alibi_slopes + a_offset) + else: + alibi_slope = None + + # compute dK and dV for blocks close to the diagonal that need to be masked + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + dk, dv = _bwd_kernel_dk_dv( + dk, dv, + Q, k, v, sm_scale, alibi_slope, + DO, + M, D, + stride_tok, stride_d, + H, N_CTX, + MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, + start_n, start_m, num_steps, + MASK=True + ) + + # compute dK and dV for blocks that don't need masking further from the diagonal + start_m += num_steps * MASK_BLOCK_M1 + num_steps = (N_CTX - start_m) // BLOCK_M1 + + dk, dv = _bwd_kernel_dk_dv( + dk, dv, + Q, k, v, sm_scale, alibi_slope, + DO, + M, D, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, + start_n, start_m, num_steps, + MASK=False + ) + + DV_block_ptrs = tl.make_block_ptr( + base=DV, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1,0) + ) + tl.store(DV_block_ptrs, dv.to(v.dtype)) + + # Write back dK. + dk *= sm_scale + DK_block_ptrs = tl.make_block_ptr( + base=DK, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1,0) + ) + tl.store(DK_block_ptrs, dk.to(k.dtype)) + + # THIS BLOCK DOES DQ: + start_m = pid * BLOCK_M2 + end_n = start_m + BLOCK_M2 + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + offs_m = start_m + tl.arange(0, BLOCK_M2) + + Q_block_ptr = tl.make_block_ptr( + base=Q, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_m, 0), + block_shape=(BLOCK_M2, BLOCK_DMODEL), + order=(1, 0) + ) + + DO_block_ptr = tl.make_block_ptr( + base=DO, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_m, 0), + block_shape=(BLOCK_M2, BLOCK_DMODEL), + order=(1, 0) + ) + q = tl.load(Q_block_ptr) + do = tl.load(DO_block_ptr) + dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) + + m = tl.load(M + offs_m) + m = m[:, None] + + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + num_steps = BLOCK_M2 // MASK_BLOCK_N2 + dq = _bwd_kernel_dq(dq, q, K, V, + do, m, D, alibi_slope, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL, + start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, + MASK=True + ) + end_n -= num_steps * MASK_BLOCK_N2 + # stage 2 + num_steps = end_n // BLOCK_N2 + dq = _bwd_kernel_dq(dq, q, K, V, + do, m, D, alibi_slope, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M2, BLOCK_N2, BLOCK_DMODEL, + start_m, end_n - num_steps * BLOCK_N2, num_steps, + MASK=False + ) + # Write back dQ. + DQ_block_ptr = tl.make_block_ptr( + base=DQ, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_m, 0), + block_shape=(BLOCK_M2, BLOCK_DMODEL), + order=(1, 0) + ) + dq *= LN2 + tl.store(DQ_block_ptr, dq.to(q.dtype)) + +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, o, metadata): + # NOTE: a large bias tensor leads to overflow during pointer arithmetic + if (metadata.bias is not None): + assert(metadata.bias.numel() < 2 ** 31) + + if o is None: + o = torch.empty_like(q, dtype=v.dtype) + import os + if os.environ.get("FLASH_ATTENTION_PRINT_PARAM", "0") == "1": + print(f"triton flash attention: {q.shape=}, {k.shape=}, {v.shape}, {o.shape=}") + print(f"triton flash attention: {q.stride()=}, {k.stride()=}, {v.stride()=}, {o.stride()=}") + print(f"triton flash attention: {metadata=}") + metadata.check_args(q, k, v, o) + if metadata.varlen: + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + batch = metadata.num_contexts + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + else: + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3)) + k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3)) + v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3)) + o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3)) + + # Get closest power of 2 over or equal to 32. + padded_d_model = 1 << (head_size - 1).bit_length() + padded_d_model = max(padded_d_model, 16) + + grid = lambda META: ( + triton.cdiv(metadata.max_seqlens_q, META['BLOCK_M']), + nheads_q, + batch + ) + + # encoded_softmax is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out + # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according + # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing + # only. This return holds no useful output aside from debugging. + if metadata.return_encoded_softmax: + encoded_softmax = torch.zeros((q.shape[0], q.shape[1], q.shape[2], k.shape[2]), device=q.device, dtype=torch.float32) + else: + encoded_softmax = None + + M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), device=q.device, dtype=torch.float32) + + # Seed the RNG so we get reproducible results for testing. + philox_seed = 0x1BF52 + philox_offset = 0x1D4B42 + + if metadata.bias is not None: + bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1), + metadata.bias.stride(2), metadata.bias.stride(3)) + else: + bias_strides = (0,0,0,0) + + if metadata.alibi_slopes is not None: + alibi_strides = (metadata.alibi_slopes.stride(0), metadata.alibi_slopes.stride(1)) + else: + alibi_strides = (0, 0) + + attn_fwd[grid]( + q, k, v, metadata.bias, metadata.sm_scale, M, o, + *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, *alibi_strides, + metadata.cu_seqlens_q, metadata.cu_seqlens_k, + dropout_p=metadata.dropout_p, + philox_seed=philox_seed, + philox_offset_base=philox_offset, + encoded_softmax=encoded_softmax, + alibi_slopes = metadata.alibi_slopes, + HQ=nheads_q, HK=nheads_k, + ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=metadata.max_seqlens_q, + MAX_SEQLENS_K=metadata.max_seqlens_k, + IS_CAUSAL=metadata.causal, + VARLEN=metadata.varlen, + BLOCK_DMODEL=padded_d_model, + BIAS_TYPE=0 if metadata.bias is None else 1, + USE_ALIBI=False if metadata.alibi_slopes is None else True, + ENABLE_DROPOUT=metadata.dropout_p > 0.0, + RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, + BATCH_SIZE= q.shape[0] + ) + if os.environ.get("FLASH_ATTENTION_PRINT_PARAM", "0") == "1": + best_config = attn_fwd.get_best_config() + print(f"{best_config.kwargs=}, {best_config.num_stages=}, {best_config.num_warps=}") + + ctx.save_for_backward(q, k, v, o, M) + ctx.grid = grid + ctx.sm_scale = metadata.sm_scale + ctx.BLOCK_DMODEL = head_size + ctx.causal = metadata.causal + ctx.alibi_slopes = metadata.alibi_slopes + ctx.dropout_p = metadata.dropout_p + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.encoded_softmax = encoded_softmax + ctx.return_encoded_softmax = metadata.return_encoded_softmax + return o if not metadata.return_encoded_softmax else (o, encoded_softmax, ) # S_dmask + + @staticmethod + def backward(ctx, do, *args): + if torch.version.hip is not None: + BLOCK = 64 + else: + BLOCK = 128 + q, k, v, o, M = ctx.saved_tensors + import os + if os.environ.get("TRITON_FLASHATTN_DEBUG", "0") == "1": + print(f"triton flash attention: {q.shape=}, {k.shape=}, {v.shape}, {o.shape=}, {do.shape=}") + print(f"triton flash attention: {q.stride()=}, {k.stride()=}, {v.stride()=}, {o.stride()=}, {do.stride()}") + # assert do.is_contiguous() + assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() + seqlen_q = q.shape[2] + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + BATCH, N_HEAD, N_CTX = q.shape[:3] + PRE_BLOCK = 128 + NUM_WARPS, NUM_STAGES = 4, 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32 + BLK_SLICE_FACTOR = 2 + RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) + arg_k = k + arg_k = arg_k * (ctx.sm_scale * RCP_LN2) + assert N_CTX % PRE_BLOCK == 0 + delta = torch.empty_like(M) + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + padded_head = (Lk != ctx.BLOCK_DMODEL) + grid_preprocess = (triton.cdiv(do.shape[2], BLOCK), do.shape[1], do.shape[0]) + _attn_bwd_preprocess[grid_preprocess]( + o, do, delta, + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + do.stride(0), do.stride(1), do.stride(2), do.stride(3), + seqlen_q, + head_dim=Lk, + BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL, + ) + grid = lambda META: ( + triton.cdiv(N_CTX, META['BLOCK_N1']), + 1, + BATCH * N_HEAD + ) + _attn_bwd[grid]( + q, arg_k, v, ctx.sm_scale, ctx.alibi_slopes, do, dq, dk, dv, + M, delta, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + N_HEAD, N_CTX, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, + BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + USE_ALIBI= False if ctx.alibi_slopes is None else True, + ) + + return dq, dk, dv, None, None + +attention = _attention.apply + +# flash_attn wrapper + +def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype): + torch.manual_seed(20) + + # Initialize q, k, v + q = torch.randn((Z, HQ, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((Z, HK, N_CTX_K, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((Z, HK, N_CTX_K, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.max_seqlens_q = N_CTX_Q + input_metadata.max_seqlens_k = N_CTX_K + return q, k, v, input_metadata + +def padding_bshd(t): # BSHD + batch, seqlen, nheads, dim = t.shape + t = torch.nn.functional.pad(t.reshape(batch, seqlen, nheads*dim), (0, 32), 'constant', 0)[:,:,:-32].reshape(batch, seqlen, nheads, dim) # pad: nheads*dim+32 + # t = torch.nn.functional.pad(t.reshape(batch, seqlen, nheads, dim), (0, 32), 'constant', 0)[:,:,:,:-32] # pad: dim+32 + return t + +def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False, padding_input=False): + if padding_input: + k, v = (padding_bshd(t) for t in (k, v)) + q, k, v = (t.transpose(1, 2) for t in (q, k, v)) + softmax_scale = softmax_scale if softmax_scale else q.shape[-1]**-0.5 + input_metadata = MetaData(sm_scale=softmax_scale, causal=causal, dropout_p=dropout_p, return_encoded_softmax=return_attn_probs) + input_metadata.max_seqlens_q = q.shape[2] + input_metadata.max_seqlens_k = k.shape[2] + return _attention.apply(q, k, v, None, input_metadata).transpose(1, 2) + +def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False, padding_input=False): + k, v = kv[:, :, 0], kv[:, :, 1] # batch_size, seqlen, 2, nheads_k, headdim + if padding_input: + k, v = (padding_bshd(t) for t in (k, v)) # pad + q, k, v = (t.transpose(1, 2) for t in (q, k, v)) # trans bshd to bhsd + softmax_scale = softmax_scale if softmax_scale else q.shape[-1]**-0.5 + input_metadata = MetaData(sm_scale=softmax_scale, causal=causal, dropout_p=dropout_p, return_encoded_softmax=return_attn_probs) + input_metadata.max_seqlens_q = q.shape[2] + input_metadata.max_seqlens_k = k.shape[2] + return _attention.apply(q, k, v, None, input_metadata).transpose(1, 2) # trans bhsd to bshd + +def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False, padding_input=False): + q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] + if padding_input: + k, v = (padding_bshd(t) for t in (k, v)) + q, k, v = (t.transpose(1, 2) for t in (q, k, v)) + softmax_scale = softmax_scale if softmax_scale else q.shape[-1]**-0.5 + input_metadata = MetaData(sm_scale=softmax_scale, causal=causal, dropout_p=dropout_p, return_encoded_softmax=return_attn_probs) + input_metadata.max_seqlens_q = q.shape[2] + input_metadata.max_seqlens_k = k.shape[2] + return _attention.apply(q, k, v, None, input_metadata).transpose(1, 2) + +# varlen flash_attn + +def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal): + torch.manual_seed(20) + + # Random sequence lengths. Using N_CTX as kind of max of sum of individual seqs + max_seqlens_q = N_CTX_Q // Z + max_seqlens_k = N_CTX_K // Z + seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32) + seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32) + max_seqlens_q = torch.max(seqlens_q).item() + max_seqlens_k = torch.max(seqlens_k).item() + + # Calculate cumulative sequence lengths + cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0, dtype=torch.int32)]) + cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0, dtype=torch.int32)]) + cu_seqlens_q = cu_seqlens_q.to(device="cuda") + cu_seqlens_k = cu_seqlens_k.to(device="cuda") + # -1 because the last entry of cu_seqlens_q specifies the end of the last seq + num_ctxs = len(cu_seqlens_q) - 1 + + # Initialize q, k, v with variable lengths + total_q = cu_seqlens_q[-1].item() + total_k = cu_seqlens_k[-1].item() + q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) + input_metadata.max_seqlens_q = max_seqlens_q + input_metadata.max_seqlens_k = max_seqlens_k + if causal: + input_metadata.need_causal() + return q, k, v, input_metadata + +def padding_thd(t): # THD + total_seqlen, nheads, dim = t.shape + t = torch.nn.functional.pad(t.reshape(total_seqlen, nheads*dim), (0, 32), 'constant', 0)[:,:-32].reshape(total_seqlen, nheads, dim) # pad: nheads*dim+32 + # t = torch.nn.functional.pad(t.reshape(total_seqlen, nheads, dim), (0, 32), 'constant', 0)[:,:-32] # pad: dim+32 + return t + +def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlens, dropout_p=0.0, softmax_scale=None, + causal=False, return_attn_probs=False, padding_input=False): + q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2] # total_seqlen, 3, nheads, dim + if padding_input: + k, v = (padding_thd(t) for t in (k, v)) # pad + softmax_scale = softmax_scale if softmax_scale else q.shape[-1]**-0.5 + input_metadata = MetaData(sm_scale=softmax_scale, causal=causal, dropout_p=dropout_p, return_encoded_softmax=return_attn_probs) + input_metadata.set_varlen_params(cu_seqlens, cu_seqlens) + input_metadata.max_seqlens_q = max_seqlens + input_metadata.max_seqlens_k = max_seqlens + return _attention.apply(q, k, v, None, input_metadata) + + +def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k, + dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False, padding_input=False): + k, v = kv[:, 0], kv[:, 1] # total_seqlen, 2, nheads, dim + if padding_input: + k, v = (padding_thd(t) for t in (k, v)) + softmax_scale = softmax_scale if softmax_scale else q.shape[-1]**-0.5 + input_metadata = MetaData(sm_scale=softmax_scale, causal=causal, dropout_p=dropout_p, return_encoded_softmax=return_attn_probs) + input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) + input_metadata.max_seqlens_q = max_seqlens_q + input_metadata.max_seqlens_k = max_seqlens_k + return _attention.apply(q, k, v, None, input_metadata) + + +def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k, + dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False, padding_input=False): + if padding_input: + k, v = (padding_thd(t) for t in (k, v)) + softmax_scale = softmax_scale if softmax_scale else q.shape[-1]**-0.5 + input_metadata = MetaData(sm_scale=softmax_scale, causal=causal, dropout_p=dropout_p, return_encoded_softmax=return_attn_probs) + input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) + input_metadata.max_seqlens_q = max_seqlens_q + input_metadata.max_seqlens_k = max_seqlens_k + return _attention.apply(q, k, v, None, input_metadata) + +# legacy interface +def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlens, dropout_p, softmax_scale=None, + causal=False, return_attn_probs=False, padding_input=False): + return flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlens, dropout_p, softmax_scale, + causal, return_attn_probs, padding_input) + +def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k, + dropout_p, softmax_scale=None, causal=False, return_attn_probs=False, padding_input=False): + return flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k, + dropout_p, softmax_scale, causal, return_attn_probs, padding_input) + +def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k, + dropout_p, softmax_scale=None, + causal=False, return_attn_probs=False, padding_input=False): + return flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlens_q, max_seqlens_k, + dropout_p, softmax_scale, causal, return_attn_probs, padding_input) + + diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py new file mode 100644 index 0000000..dcf99f0 --- /dev/null +++ b/vllm/attention/ops/flashmla.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py +from typing import Optional, Tuple + +import torch + +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +if current_platform.is_cuda(): + try: + import vllm._flashmla_C # noqa: F401 + _flashmla_C_AVAILABLE = True + except ImportError: + _flashmla_C_AVAILABLE = False +else: + _flashmla_C_AVAILABLE = False + +if current_platform.is_rocm(): + import flash_mla_cuda + _flashmla_C_AVAILABLE = True + +def is_flashmla_supported() -> Tuple[bool, Optional[str]]: + """ + Return: is_supported_flag, unsupported_reason (optional). + """ + if not (current_platform.is_cuda() or current_platform.is_rocm()): + return False, "FlashMLA is supported on CUDA and ROCM devices." + if current_platform.get_device_capability()[0] != 9: + return False, "FlashMLA is only supported on Hopper devices." + if not _flashmla_C_AVAILABLE: + return False, "vllm._flashmla_C is not available, likely was not "\ + "compiled due to insufficient nvcc version or a supported arch "\ + "(only sm90a currently) was not in the list of target arches to "\ + "compile for." + return True, None + + +def get_mla_metadata( + cache_seqlens: torch.Tensor, + num_heads_per_head_k: int, + num_heads_k: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + cache_seqlens: (batch_size), dtype torch.int32. + num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. + num_heads_k: num_heads_k. + + Return: + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), + dtype torch.int32. + num_splits: (batch_size + 1), dtype torch.int32. + """ + if current_platform.is_rocm(): + return flash_mla_cuda.get_mla_metadata(cache_seqlens, + num_heads_per_head_k, + num_heads_k) + else: + return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens, + num_heads_per_head_k, + num_heads_k) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, + k_scale = None, + kv_cache_dtype = "auto", +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. + cache_seqlens: (batch_size), torch.int32. + head_dim_v: Head_dim of v. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), + torch.int32, return by get_mla_metadata. + num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(head_dim). + causal: bool. Whether to apply causal attention mask. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1]**(-0.5) + if current_platform.is_rocm(): + if kv_cache_dtype == "fp8": + out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla( + q, + k_cache, + None, + head_dim_v, + cache_seqlens, + block_table, + softmax_scale, + causal, + tile_scheduler_metadata, + num_splits, + k_scale, + "fp8_e4m3", + ) + return out, softmax_lse + out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( + q, + k_cache, + None, + head_dim_v, + cache_seqlens, + block_table, + softmax_scale, + causal, + tile_scheduler_metadata, + num_splits, + ) + else: + out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( + q, + k_cache, + None, + head_dim_v, + cache_seqlens, + block_table, + softmax_scale, + causal, + tile_scheduler_metadata, + num_splits, + ) + return out, softmax_lse + + +# +# TODO: Add fake functions +# +# @register_fake("_flashmla_C::get_mla_metadata") +# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]: +# return .... +# +# @register_fake("_flashmla_C::fwd_kvcache_mla") +# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]: +# return .... +# \ No newline at end of file diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py new file mode 100644 index 0000000..412dd20 --- /dev/null +++ b/vllm/attention/ops/hpu_paged_attn.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +############################################################################### +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company +############################################################################### + +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch +from vllm_hpu_extension import cache_ops, ops + +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. +_PARTITION_SIZE = 512 + + +@dataclass +class HPUPagedAttentionMetadata: + """Metadata for PagedAttention.""" + block_list: Optional[torch.Tensor] + block_mapping: Optional[torch.Tensor] + block_usage: Optional[torch.Tensor] + block_indices: Optional[torch.Tensor] + block_offsets: Optional[torch.Tensor] + block_groups: Optional[torch.Tensor] + + +class HPUPagedAttention: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 80, 96, 112, 128, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, kv_cache_dtype: str, + is_prompt: bool) -> None: + cache_ops.reshape_and_cache(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype, is_prompt) + + @staticmethod + def forward_decode(**kwargs) -> torch.Tensor: + return ops.flat_pa(**kwargs) + + @staticmethod + def swap_blocks( + src_kv_cache: Tuple[torch.Tensor, torch.Tensor], + dst_kv_cache: Tuple[torch.Tensor, torch.Tensor], + src_to_dsts: torch.Tensor, + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dsts) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dsts) + + @staticmethod + def copy_blocks( + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + src_to_dsts: torch.Tensor, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts) diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py new file mode 100644 index 0000000..8919754 --- /dev/null +++ b/vllm/attention/ops/ipex_attn.py @@ -0,0 +1,195 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import List, Optional, Tuple + +try: + import intel_extension_for_pytorch.llm.modules as ipex_modules + _use_ipex = True +# AttributeError is to handle a bug in ipex https://github.com/intel/intel-extension-for-pytorch/pull/813 +except (ImportError, AttributeError): + _use_ipex = False + +import torch + +from vllm import _custom_ops as ops + + +class _PagedAttention: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 80, 96, 112, 128, 192, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + *args, + ) -> Tuple[int, ...]: + return 2, num_blocks, block_size * num_kv_heads * head_size + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + *args, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 16 // kv_cache.element_size() + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + + @staticmethod + def forward_decode( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + tp_rank: int = 0 + blocksparse_local_blocks: int = 0 + blocksparse_vert_stride: int = 0 + blocksparse_block_size: int = 64 + blocksparse_head_sliding_step: int = 0 + block_size = value_cache.shape[3] + + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + *args, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +class _IPEXPagedAttention(_PagedAttention): + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + *args, + ) -> Tuple[torch.Tensor, torch.Tensor]: + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + ipex_modules.PagedAttention.reshape_and_cache( + key, value, key_cache, value_cache, + slot_mapping.flatten().int()) + + @staticmethod + def forward_decode( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + block_size = value_cache.shape[2] + head_mapping = torch.arange( + 0, + num_kv_heads, + device="cpu", + dtype=torch.int32, + ).view(num_kv_heads, + 1).repeat_interleave(query.size(1) // num_kv_heads).flatten() + ipex_modules.PagedAttention.single_query_cached_kv_attention( + output, query.contiguous(), key_cache, value_cache, head_mapping, + scale, block_tables, context_lens, block_size, max_context_len, + alibi_slopes) + + +PagedAttention = _IPEXPagedAttention if _use_ipex else _PagedAttention diff --git a/vllm/attention/ops/merge_attn_states.py b/vllm/attention/ops/merge_attn_states.py new file mode 100644 index 0000000..d052b76 --- /dev/null +++ b/vllm/attention/ops/merge_attn_states.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +from vllm.platforms import current_platform +from vllm import envs + + +def merge_attn_states( + output: torch.Tensor, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + suffix_output: torch.Tensor, + suffix_lse: torch.Tensor, + output_lse: Optional[torch.Tensor] = None, +) -> None: + + # NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel + # is not support for FP8 dtype, fallback to use Triton kernel. + def supported_dtypes(o: torch.Tensor) -> bool: + return o.dtype in [torch.float32, torch.half, torch.bfloat16] + + # NOTE(DefTruth): Currently, custom merge_attn_states CUDA + # kernel load/store 128b(16 bytes) per memory issue within + # thread. Namely, the headsize(headdim) must be multiple of + # pack_size (float32 -> 4, half/bfloat16 -> 8). + def supported_headdim(o: torch.Tensor) -> bool: + headdim = o.shape[2] # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + if o.dtype == torch.float32: + return headdim % 4 == 0 + return headdim % 8 == 0 + + if (current_platform.is_cuda() or envs.VLLM_USE_MERGE_ATTN_STATES_OPT and supported_dtypes(output) + and supported_headdim(output)): + from vllm._custom_ops import merge_attn_states + return merge_attn_states(output, prefix_output, prefix_lse, + suffix_output, suffix_lse, output_lse) + else: + from vllm.attention.ops.triton_merge_attn_states import ( + merge_attn_states) + return merge_attn_states(output, prefix_output, prefix_lse, + suffix_output, suffix_lse, output_lse) diff --git a/vllm/attention/ops/nki_flash_attn.py b/vllm/attention/ops/nki_flash_attn.py new file mode 100644 index 0000000..29fa432 --- /dev/null +++ b/vllm/attention/ops/nki_flash_attn.py @@ -0,0 +1,903 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import neuronxcc.nki.isa as nisa +import neuronxcc.nki.language as nl +import numpy as np +import torch +from neuronxcc import nki +from neuronxcc.nki.language import par_dim + +from vllm.utils import cdiv + + +def is_power_of_2(x): + return x > 0 and (x & (x - 1)) == 0 + + +@nki.jit +def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile): + """ + Load block tables from HBM into SRAM + + `block_tables_hbm` has shape `(num_tiles * num_blocks_per_tile, )`. + In case `num_tiles > B_P_SIZE`, we need further tile `num_tile` dimension. + """ + B_P_SIZE = 128 + + # reshape as `(num_tiles, num_blocks_per_tile)` + assert len(block_tables_hbm.shape) == 1 + (num_total_blocks, ) = block_tables_hbm.shape + assert num_blocks_per_tile * num_tiles == num_total_blocks + block_tables_hbm = block_tables_hbm.reshape( + (num_tiles, num_blocks_per_tile)) + + block_tables_sbuf = nl.zeros( + (cdiv(num_tiles, B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile), + dtype=nl.int32, + ) + for i in nl.affine_range(cdiv(num_tiles, B_P_SIZE)): + i_p = nl.arange(B_P_SIZE)[:, None] + i_f = nl.arange(num_blocks_per_tile)[None, :] + block_tables_sbuf[i, i_p, i_f] = nl.load( + block_tables_hbm[i_p + i * B_P_SIZE, i_f], + dtype=nl.int32, + mask=(i_p + i * B_P_SIZE < num_tiles), + ) + return block_tables_sbuf + + +@nki.jit +def transform_block_tables_for_indirect_load( + block_tables, + block_size_tiling_factor, + num_head, + head_id, +): + """ + This function does two things: + 1. calculate new `block_tables` for a `head_id` after flattening + `num_block`, `num_head`, and `block_size_tiling_factor` dimensions + 2. transpose the result so that `block_table` for each tile is mapped to + SBUF Partition dimension for vectorized DMA + + Tiling trick to further improve DMA performance: + Given KV cache shape `(num_block, num_head, block_size, D)`, when loading M + blocks of a given `head_id` from HBM, the load `cache[block_tables, + head_id]` has shape `(M, block_size, D)`. If M < B_P_SIZE = 128, DMA may not + fully utilize hardware parallelization. The solution is to tile `block_size` + into `(block_size_tiling_factor, tiled_block_size)` s.t. `M * + block_size_tiling_factor = B_P_SIZE`. After tiling, KV cache has shape + `(num_block, num_head, block_size_tiling_factor, tiled_block_size, D)`. + + Note: + We don't further tile D dimension as small DMA size also hurts performance. + """ + B_P_SIZE = 128 + num_partitions, num_tiles_per_partition, num_blocks_per_tile = ( + block_tables.shape) + assert num_tiles_per_partition == B_P_SIZE + assert is_power_of_2( + num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2" + + num_loads = cdiv(num_blocks_per_tile, B_P_SIZE) + block_tables_transposed = nl.ndarray( + ( + num_loads, + par_dim(B_P_SIZE), + num_partitions * num_tiles_per_partition, + ), + dtype=nl.int32, + ) + + # prepare iota ahead of time to avoid repeatedly using Gpsimd + if num_head > 1: + head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1)) + head_id = nl.transpose( + head_id.broadcast_to((1, num_tiles_per_partition))) + if num_blocks_per_tile > 1: + head_id = head_id.broadcast_to( + (num_tiles_per_partition, num_blocks_per_tile)) + + if block_size_tiling_factor > 1: + broadcast_shape = ( + num_tiles_per_partition, + num_blocks_per_tile, + block_size_tiling_factor, + ) + offset = nisa.iota(nl.arange(block_size_tiling_factor)[None, None, :], + dtype=nl.int32).broadcast_to(broadcast_shape) + + for partition_id in nl.affine_range(num_partitions): + block_tables_partition = block_tables[partition_id] + if num_head > 1: + # fuse num_block and num_head dimension + block_tables_partition = block_tables_partition * num_head + head_id + + if block_size_tiling_factor > 1: + # need to apply block size tiling trick + assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE + block_tables_partition = ((block_tables_partition * + block_size_tiling_factor).reshape( + (num_tiles_per_partition, + num_blocks_per_tile, + 1)).broadcast_to(broadcast_shape)) + new_block_tables = block_tables_partition + offset + new_block_tables = new_block_tables.reshape( + (num_tiles_per_partition, B_P_SIZE)) + else: + new_block_tables = block_tables_partition + + # transpose the block table so that it can be used by vector DGE + for i in nl.affine_range(num_loads): + i_p = nl.arange(B_P_SIZE)[:, None] + i_f = (partition_id * num_tiles_per_partition + + nl.arange(num_tiles_per_partition)[None, :]) + block_tables_transposed[i, i_p, i_f] = nl.transpose( + new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)]) + return block_tables_transposed + + +@nki.jit +def load_kv_tile_from_cache( + cur_k_tile, + cur_v_tile, + kv_cache, + block_tables, + large_k_tile_idx, + num_blocks_per_large_tile, + tiled_block_size, + B_P_SIZE, + B_D_SIZE, +): + """ + Load KV cache and transform Key and Value into layout required by Matmul + + Vectorized DMA Load layout: + Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE) + + Layout used by attention matmuls: + Key: (par_dim(B_D_SIZE), seqlen_kv) + Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE) + equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE) + """ + # load key cache + num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE) + for load_idx in nl.affine_range(num_loads): + i_p = nl.arange(B_P_SIZE)[:, None] + i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :] + loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p, + large_k_tile_idx], i_f]) + if cur_k_tile.dtype != loaded.dtype: + loaded = nl.copy(loaded, dtype=cur_k_tile.dtype) + # Transpose SBUF tensor using PE + for tb_i in nl.affine_range(tiled_block_size): + cur_k_tile[ + :, + nl.ds( + load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE, + B_P_SIZE, + ), + ] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)]) + + # load value cache + for load_idx in nl.affine_range(num_loads): + loaded = nl.load(kv_cache[1, block_tables[load_idx, i_p, + large_k_tile_idx], i_f]) + if cur_v_tile.dtype != loaded.dtype: + loaded = nl.copy(loaded, dtype=cur_v_tile.dtype) + i_p = nl.arange(B_P_SIZE)[:, None] + i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :] + cur_v_tile[ + :, + nl.ds( + load_idx * tiled_block_size * B_D_SIZE, + tiled_block_size * B_D_SIZE, + ), + ] = loaded + + +@nki.jit +def transpose_p_local(p_local_transposed, + p_local, + LARGE_TILE_SZ, + B_F_SIZE=512): + for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE): + if nisa.get_nc_version() == nisa.nc_version.gen3: + p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE), + buffer=nl.sbuf, + dtype=p_local.dtype) + else: + p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE), + buffer=nl.psum, + dtype=np.float32) + + for j in nl.affine_range(B_F_SIZE // 128): + j_128_slice = nl.ds(j * 128, 128) + i_j_128_slice = nl.ds(i * B_F_SIZE + j * 128, 128) + + if nisa.get_nc_version() == nisa.nc_version.gen3: + p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose( + p_local[:, i_j_128_slice]) + else: + p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose( + p_local[:, i_j_128_slice]) + + p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy( + p_local_t_tmp, dtype=p_local_transposed.dtype) + + +@nki.jit +def _flash_attention_core( + q_local_tile, + k, + v, + o_buffer, + l_buffer, + m_buffer, + kernel_dtype, + acc_type, + tile_mask, + use_causal_mask, + q_tile_idx=None, + initialize=False, + LARGE_TILE_SZ=2048, + B_P_SIZE=128, + B_F_SIZE=512, + B_D_SIZE=128, + qk_res_buffer=None, +): + """ + The flash attention core function to calculate self attention between a tile + of q and a block of K and V. + The q_local_tile has (B_P_SIZE, B_D_SIZE) + The K and V have shape (B_D_SIZE, LARGE_TILE_SZ), whose free dimension will + be split into size B_F_SIZE tiles + + The results are stored in the following three buffers + o_buffer: (B_P_SIZE, d) + l_buffer: (B_P_SIZE, 1) + m_buffer: (B_P_SIZE, 1) + + All IO buffers are in SBUF. + """ + num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE + + qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), + buffer=nl.sbuf, + dtype=acc_type) + max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile), + dtype=acc_type) + for k_i in nl.affine_range(num_k_tile_per_large_tile): + k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE) + + if use_causal_mask: + # mask are used to only apply computation to the lower half of the + # matrix, which reduce the arithmetic intensity by up to 50% + multiplication_required_selection = (q_tile_idx * B_P_SIZE + >= k_i * B_F_SIZE) + else: + multiplication_required_selection = True + + if multiplication_required_selection: + qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE), + dtype=np.float32, + buffer=nl.psum) # (128, 512) + qk_psum[:, :] = nl.matmul(q_local_tile, + k[:, k_i_b_f_slice], + transpose_x=True) # (p(128), 512) + qk_res_buf[:, k_i_b_f_slice] = nl.where( + tile_mask[:, k_i_b_f_slice], + qk_psum[:, nl.ds(0, B_F_SIZE)], + -9984.0, + dtype=acc_type, + ) + else: + qk_res_buf[:, k_i_b_f_slice] = -9984.0 + + # Calculate max of the current tile + max_local[:, k_i] = nisa.tensor_reduce( + np.max, + qk_res_buf[:, k_i_b_f_slice], + axis=(1, ), + dtype=acc_type, + negate=False, + ) + + if qk_res_buffer is not None: + qk_res_buffer[:, :] = nl.copy(qk_res_buf[:, :]) + + max_ = nisa.tensor_reduce( + np.max, + max_local[:, :], + axis=(1, ), + dtype=acc_type, + negate=False, + ) + + o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE), + dtype=o_buffer.dtype) + + if initialize: + m_buffer[:, 0] = nl.copy(max_) + m_current = max_ + else: + m_previous = nl.copy(m_buffer[:, 0]) + m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1) + + m_current = m_buffer[:, 0] + # Compute scaling factor + alpha = nisa.activation( + np.exp, + m_previous, + bias=-1 * m_current, + scale=1.0, + ) + o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha) + + p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype) + REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2) + + p_partial_sum = nl.ndarray( + (par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), + dtype=acc_type, + ) + + for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE): + k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE) + + # compute exp(qk - max) + # Compute partial row - tile sum of exp(qk - max)) + # FIXME : Use activation accumulate to accumulate over k_r_i loop ? + p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce( + np.exp, + qk_res_buf[:, k_r_i_reduce_slice], + bias=-1 * m_current, + scale=1.0, + reduce_op=nl.add, + reduce_res=p_partial_sum[:, k_r_i], + dtype=kernel_dtype, + ) + + ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type) + + p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype) + transpose_p_local( + p_local_transposed=p_local_transposed, + p_local=p_local, + LARGE_TILE_SZ=LARGE_TILE_SZ, + B_F_SIZE=B_F_SIZE, + ) + + pv_psum = nl.zeros( + (par_dim(B_P_SIZE), B_D_SIZE), + dtype=np.float32, + buffer=nl.psum, + ) + for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE): + pv_psum[:, :] += nl.matmul( + p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)], + v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)], + transpose_x=True, + ) # (128, 128) (p(Br), d) + + if initialize: + o_buffer[:, :] = nl.copy(pv_psum[:, :]) + l_buffer[:, 0] = nl.add(nl.log(ps), max_) + else: + o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum) + + l_prev = l_buffer[:, 0] + l_exp = nl.add( + nl.exp(nl.subtract(l_prev, m_current)), + ps, + ) + l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp)) + + +@nki.jit +def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ): + B_P_SIZE = 128 + B_D_SIZE = v_hbm_tile.shape[-1] + loaded = nl.load(v_hbm_tile[ + nl.ds(large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), + :, + ]) + if cur_v_tile.dtype != loaded.dtype: + loaded = nl.copy(loaded, dtype=cur_v_tile.dtype) + cur_v_tile[:, nl.ds(v_i * B_D_SIZE, B_D_SIZE)] = loaded + + +@nki.jit +def flash_paged_attention( + query, + key, + value, + kv_cache, + block_tables, + mask, + softmax_scale=None, + mixed_precision=True, + LARGE_TILE_SZ=2048, + return_debug_tensors=False, +): + """ + Flash PagedAttention Forward Kernel. + + IO tensor layouts: + - query: shape (1, n_heads, d, seq_q) + - key: shape (1, n_kv_heads, d, seq_k) + - value: shape (1, n_kv_heads, seq_v, d) + - kv_cache: (2, num_blocks, n_kv_heads, block_size, d) + - block_tables: (num_active_blocks, ) + - mask: (seq_q, num_active_blocks * block_size + seq_q) + - o: shape (1, n_heads, seq_q, d) + + - This kernel requires seq_k == seq_v + - We use continuous batching by default, so the batch dimension is + always 1, and different requests are concatenated along sequence + dimension. + - We use paged cache blocks (kv_cache) to store KV cache. + + IO tensor dtypes: + - This kernel assumes all IO tensors have the same dtype except for + block_tables (int32) and mask (int32) + - If mixed_precision is True, then all Tensor Engine operation will be + performed in bfloat16 and accumulation will be performed in float32. + Otherwise the intermediates will be in the same type as the inputs. + + Compile-time Constants: + - softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)` + - mixed_precision: flag to set non-matmul ops in fp32 precision, default + is set to `true`, if false, we use same precision as input types + - LARGE_TILE_SZ: `default=2048`, size of the kv tile size for attention + computation reduction + + GQA support Notes: + the spmd kernel for launching kernel should be on kv_heads instead of + nheads + + Example usage: + MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d] + usage: `flash_fwd[b, h](q, k, v, ...)` + GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d] + usage: `flash_fwd[b, kv_h](q, k, v, ...)` + """ + B_F_SIZE = 512 + B_P_SIZE = 128 + b, h, d, seqlen_q = query.shape + B_D_SIZE = d + n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine + _, num_blocks, k_h, block_size, _ = kv_cache.shape + q_h_per_k_h = h // k_h + assert b == 1, f"invalid batch size {b=}" + assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}" + cache_shape = (2, num_blocks, k_h, block_size, d) + assert (tuple(kv_cache.shape) == cache_shape + ), f"{kv_cache.shape=} mismatch, expect {cache_shape}" + assert key is None or tuple(key.shape) == ( + 1, + k_h, + d, + seqlen_q, + ), f"key shape {key.shape} mismatch!" + assert value is None or tuple(value.shape) == ( + 1, + k_h, + seqlen_q, + d, + ), f"value shape {value.shape} mismatch!" + + assert ( + nl.program_ndim() == 2 + ), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!" + batch_id = nl.program_id(axis=0) + head_id = nl.program_id(axis=1) + + (num_active_blocks, ) = block_tables.shape + context_kv_len = num_active_blocks * block_size + assert ( + LARGE_TILE_SZ % B_F_SIZE == 0 + ), f"Need {LARGE_TILE_SZ=} to be divisible by {B_F_SIZE=} in transpose_p" + assert (context_kv_len % LARGE_TILE_SZ == 0 + ), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}" + + num_blocks_per_large_tile = LARGE_TILE_SZ // block_size + assert is_power_of_2( + num_blocks_per_large_tile + ), f"{num_blocks_per_large_tile=} is expected of be power of 2" + if seqlen_q > B_F_SIZE: + MAX_REDUCTION_TILE = 2048 + if seqlen_q // 2 > MAX_REDUCTION_TILE: + assert ( + seqlen_q % MAX_REDUCTION_TILE == 0 + ), f"{seqlen_q=} should be divisible by {MAX_REDUCTION_TILE=}" + else: + assert (seqlen_q % B_F_SIZE == 0 + ), f"{seqlen_q=} should be divisible by {B_F_SIZE=})" + + kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype + acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype + softmax_scale = softmax_scale or (1.0 / (d**0.5)) + num_large_k_tile = context_kv_len // LARGE_TILE_SZ + + o = nl.ndarray((b, h, seqlen_q, d), + dtype=query.dtype, + buffer=nl.shared_hbm) + hbm_l_buffer, hbm_m_buffer, hbm_qk_res, qk_res_buffer = ( + None, + None, + None, + None, + ) + if return_debug_tensors: + hbm_l_buffer = nl.ndarray((b, h, seqlen_q), + dtype=acc_type, + buffer=nl.shared_hbm) + hbm_m_buffer = nl.ndarray((b, h, seqlen_q), + dtype=acc_type, + buffer=nl.shared_hbm) + hbm_qk_res = nl.ndarray((b, h, B_P_SIZE, seqlen_q), + dtype=acc_type, + buffer=nl.shared_hbm) + qk_res_buffer = nl.zeros( + (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), seqlen_q), + dtype=acc_type, + buffer=nl.sbuf, + lazy_initialization=True, + ) + block_tables_sbuf = load_block_tables( + block_tables_hbm=block_tables, + num_tiles=num_large_k_tile, + num_blocks_per_tile=num_blocks_per_large_tile, + ) + + # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient + if num_blocks_per_large_tile < B_P_SIZE: + # we checked num_blocks_per_tile is a power of 2 + assert B_P_SIZE % num_blocks_per_large_tile == 0 + block_size_tiling_factor = B_P_SIZE // num_blocks_per_large_tile + # We assume block_size >= block_size_tiling_factor + assert block_size % block_size_tiling_factor == 0 + else: + block_size_tiling_factor = 1 + tiled_block_size = block_size // block_size_tiling_factor + + # Indirect DMA load must be placed along Partition Dimension + block_tables_sbuf = transform_block_tables_for_indirect_load( + block_tables_sbuf, + block_size_tiling_factor=block_size_tiling_factor, + num_head=k_h, + head_id=head_id, + ) + + # Flatten KV cache to be 3D for loading into SBUF + new_cache_shape = ( + 2, + num_blocks * k_h * block_size_tiling_factor, + tiled_block_size * d, + ) + kv_cache = kv_cache.reshape(new_cache_shape) + + # Global Flash Attention accumulators + o_buffer = nl.zeros( + (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), d), + dtype=acc_type, + buffer=nl.sbuf, + lazy_initialization=True, + ) + l_buffer = nl.zeros( + (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1), + dtype=acc_type, + buffer=nl.sbuf, + lazy_initialization=True, + ) + m_buffer = nl.zeros( + (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1), + dtype=acc_type, + buffer=nl.sbuf, + lazy_initialization=True, + ) + + for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile): + num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE) + cur_k_tile = nl.ndarray( + (par_dim(B_D_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype, + ) + cur_v_tile = nl.ndarray( + (par_dim(B_P_SIZE), num_loads * tiled_block_size * B_D_SIZE), + dtype=kernel_dtype, + ) + load_kv_tile_from_cache( + cur_k_tile=cur_k_tile, + cur_v_tile=cur_v_tile, + kv_cache=kv_cache, + block_tables=block_tables_sbuf, + large_k_tile_idx=large_k_tile_idx, + num_blocks_per_large_tile=num_blocks_per_large_tile, + tiled_block_size=tiled_block_size, + B_P_SIZE=B_P_SIZE, + B_D_SIZE=B_D_SIZE, + ) + + for i in nl.affine_range(n_tile_q): + cur_mask = nl.load(mask[ + nl.ds(i * B_P_SIZE, B_P_SIZE), + nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ), + ]) + for i_q_h in nl.affine_range(q_h_per_k_h): + q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) + q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] + q_sbuf_tile = nl.load(q_hbm_tile[:, + nl.ds(i * + B_P_SIZE, B_P_SIZE)]) + if q_sbuf_tile.dtype != kernel_dtype: + q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype) + q_tile[:, :] = q_sbuf_tile * softmax_scale + + _flash_attention_core( + q_local_tile=q_tile, + k=cur_k_tile, + v=cur_v_tile, + o_buffer=o_buffer[i, i_q_h], + l_buffer=l_buffer[i, i_q_h], + m_buffer=m_buffer[i, i_q_h], + kernel_dtype=kernel_dtype, + acc_type=acc_type, + tile_mask=cur_mask, + use_causal_mask=False, + q_tile_idx=i, + initialize=large_k_tile_idx == 0, + LARGE_TILE_SZ=LARGE_TILE_SZ, + B_P_SIZE=B_P_SIZE, + B_F_SIZE=B_F_SIZE, + B_D_SIZE=B_D_SIZE, + ) + + # compute attention between input query, key and value + if key is not None and value is not None: + B_F_SIZE = min(seqlen_q, B_F_SIZE) + LARGE_TILE_SZ = seqlen_q + + cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype) + cur_v_tile = nl.ndarray( + (par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE), + dtype=kernel_dtype, + ) + + loaded = nl.load(key[batch_id, head_id, :, :]) + if loaded.dtype != kernel_dtype: + loaded = nl.copy(loaded, dtype=kernel_dtype) + cur_k_tile[:, :] = loaded + + v_hbm_tile = value[batch_id, head_id] + for v_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE): + load_v_tile( + v_hbm_tile=v_hbm_tile, + cur_v_tile=cur_v_tile, + large_tile_idx=0, + v_i=v_i, + LARGE_TILE_SZ=LARGE_TILE_SZ, + ) + + for i in nl.affine_range(n_tile_q): + cur_mask = nl.load(mask[ + nl.ds(i * B_P_SIZE, B_P_SIZE), + nl.ds(context_kv_len, LARGE_TILE_SZ), + ]) + for i_q_h in nl.affine_range(q_h_per_k_h): + + q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) + q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] + q_sbuf_tile = nl.load(q_hbm_tile[:, + nl.ds(i * + B_P_SIZE, B_P_SIZE)]) + if q_sbuf_tile.dtype != kernel_dtype: + q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype) + q_tile[:, :] = q_sbuf_tile * softmax_scale + _flash_attention_core( + q_local_tile=q_tile, + k=cur_k_tile, + v=cur_v_tile, + o_buffer=o_buffer[i, i_q_h], + l_buffer=l_buffer[i, i_q_h], + m_buffer=m_buffer[i, i_q_h], + kernel_dtype=kernel_dtype, + acc_type=acc_type, + tile_mask=cur_mask, + use_causal_mask=True, + q_tile_idx=i, + initialize=False, + LARGE_TILE_SZ=LARGE_TILE_SZ, + B_P_SIZE=B_P_SIZE, + B_F_SIZE=B_F_SIZE, + B_D_SIZE=B_D_SIZE, + qk_res_buffer=(qk_res_buffer[i, i_q_h] + if qk_res_buffer is not None else None), + ) + + # -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- # + for i_q_h in nl.affine_range(q_h_per_k_h): + for i in nl.affine_range(n_tile_q): + out = nl.multiply( + o_buffer[i, i_q_h], + nl.exp(m_buffer[i, i_q_h] - l_buffer[i, i_q_h]), + dtype=kernel_dtype, + ) + + nl.store( + o[ + batch_id, + head_id * q_h_per_k_h + i_q_h, + nl.ds(i * B_P_SIZE, B_P_SIZE), + :, + ], + out, + ) + # maximum and summation statistics + if return_debug_tensors: + nl.store( + hbm_m_buffer[ + batch_id, + head_id * q_h_per_k_h + i_q_h, + nl.ds(i * B_P_SIZE, B_P_SIZE), + ], + m_buffer[i, i_q_h, :, :], + ) + nl.store( + hbm_l_buffer[ + batch_id, + head_id * q_h_per_k_h + i_q_h, + nl.ds(i * B_P_SIZE, B_P_SIZE), + ], + l_buffer[i, i_q_h], + ) + nl.store( + hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :], + qk_res_buffer[batch_id, i_q_h, :, :], + ) + + if return_debug_tensors: + return o, hbm_m_buffer, hbm_l_buffer, hbm_qk_res + return o + + +def reorder_context_mask(mask, LARGE_TILE_SZ, block_size): + """ + Reorder the mask to make it compatible with the flash attention kernel. + + We vectorize KV cache read to improve DMA utilization. However, the layout + that maximizes DMA bandwidth changes the order tokens are consumed. + + The token layout (inner 2 dimensions) after vectorized load is (B_P_SIZE, + tiled_block_size) in a tile of `B_P_SIZE * tiled_block_size` tokens. And + each step the engine consumes a column (rather than a row) of B_P_SIZE + tokens. Therefore, the tokens are visited in a strided way. + + To make sure mask matches the order tokens are consumed, we need to properly + transpose mask. + """ + total_query_len, total_seq_len = mask.shape + context_kv_len = total_seq_len - total_query_len + + B_P_SIZE = 128 + assert (LARGE_TILE_SZ + >= B_P_SIZE), f"{LARGE_TILE_SZ=} must be larger than {B_P_SIZE=}" + num_tiled_blocks = max(B_P_SIZE, LARGE_TILE_SZ // block_size) + tiled_block_size = LARGE_TILE_SZ // num_tiled_blocks + if tiled_block_size > 1: + # Mask reordering is needed when tiled_block_size > 1 + device = mask.device + mask = mask.cpu() + context_mask = mask[:, :context_kv_len] + context_mask = context_mask.view( + total_query_len, + context_kv_len // LARGE_TILE_SZ, + num_tiled_blocks // B_P_SIZE, + B_P_SIZE, + tiled_block_size, + ) + context_mask = context_mask.transpose(3, 4).reshape( + total_query_len, context_kv_len) + new_mask = mask[:, context_kv_len:] + return torch.concat([context_mask, new_mask], dim=1).to(device) + else: + return mask + + +def flash_attn_varlen_nkifunc( + query, + key, + value, + kv_cache, + block_table, + attn_mask, + n_kv_head=None, + head_size=None, + LARGE_TILE_SZ=2048, + mixed_precision=True, +): + """ + Compute flash paged attention for variable length sequences. + + This function is a wrapper around the flash attention NKI kernel. It takes + in the following arguments: + - query: (1, n_heads, d, seq_q) + - key: (1, n_kv_heads, d, seq_k) + - value: (1, n_kv_heads, seq_v, d) + - kv_cache: (2, n_blocks, n_kv_heads, block_size, d) + - block_tables: (n_active_blocks, ) + - attn_mask: (seq_q, n_active_blocks * block_size + seq_q) + + Notes: + - attn_mask must be reordered outside using `reorder_context_mask` + - Key/value cache layout must be (n_blocks, n_kv_heads, block_size, d) + for better DMA throughput + """ + if n_kv_head is None: + n_kv_head = kv_cache.shape[2] + assert kv_cache.shape[0] == 2 + assert kv_cache.shape[2] == n_kv_head + if head_size is None: + head_size = kv_cache.shape[-1] + + kwargs = dict( + query=query, + key=key, + value=value, + kv_cache=kv_cache, + block_tables=block_table, + mask=attn_mask, + softmax_scale=1.0 / (head_size**0.5), + mixed_precision=mixed_precision, + LARGE_TILE_SZ=LARGE_TILE_SZ, + ) + + o = flash_paged_attention[1, n_kv_head](**kwargs) + return o + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, +) -> None: + """ + Writes key-value pairs to the KV cache at specified positions. + + Args: + key (torch.Tensor): Key tensor with shape + (num_tokens, n_kv_head, d_head) + value (torch.Tensor): Value tensor with shape + (num_tokens, n_kv_head, d_head) + kv_cache (torch.Tensor): Key/value cache tensor with shape + (2, num_blocks, n_kv_head, block_size, d_head) + slot_mapping (torch.Tensor): Mapping tensor indicating cache positions + with shape (num_tokens) + + Returns: + None: Updates the kv_cache tensor in-place + """ + block_size = kv_cache.size(3) + n_kv_head = key.size(1) + + # Calculate indices with explicit floor division + block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor") + block_offsets = slot_mapping % block_size + + # Create the head indices tensor + head_indices = torch.arange(n_kv_head, device=key.device) + + # Update caches using index_put_ + kv_cache.index_put_( + (torch.tensor([0], device=key.device), block_indices[:, None], + head_indices[None, :], block_offsets[:, None]), key) + + kv_cache.index_put_( + (torch.tensor([1], device=key.device), block_indices[:, None], + head_indices[None, :], block_offsets[:, None]), value) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py new file mode 100644 index 0000000..3008dc9 --- /dev/null +++ b/vllm/attention/ops/paged_attn.py @@ -0,0 +1,504 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.triton_utils import HAS_TRITON +import vllm.envs as envs +from vllm.utils import SUPPORT_TC + +if HAS_TRITON: + from vllm.attention.ops.prefix_prefill import context_attention_fwd + +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. +_PARTITION_SIZE = 512 +use_tc = envs.VLLM_USE_OPT_OP and envs.VLLM_USE_TC_PAGED_ATTN and SUPPORT_TC + +@dataclass +class PagedAttentionMetadata: + """Metadata for PagedAttention.""" + # (batch_size,). The length of sequences (entire tokens seen so far) per + # sequence. + seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length in the batch. 0 if it is prefill-only batch. + max_decode_seq_len: int + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + +class PagedAttention: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 80, 96, 112, 120, 128, 192, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size * num_kv_heads * head_size) + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 16 // kv_cache.element_size() + num_blocks = kv_cache.shape[1] + + ''' + CUTLASS key_cache layout: [num_blocks, num_kv_heads, block_size, head_size] + Triton key_cache layout: [num_blocks, num_kv_heads, head_size // x, block_size, x] + value_cache layout: [num_blocks, num_kv_heads, head_size, block_size] + ''' + if envs.VLLM_USE_FLASH_ATTN_PA: + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size) + value_cache = kv_cache[1] + value_cache=value_cache.view(num_blocks, num_kv_heads,head_size, -1) + else: + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + ) -> None: + if envs.VLLM_USE_FLASH_ATTN_PA: + ops.reshape_and_cache_cuda( + key, + value, + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + else: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + + + @staticmethod + def forward_decode( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + attn_masks: Optional[torch.Tensor] = None, + attn_masks_stride: int = 0 + ) -> torch.Tensor: + if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: + # use blocksparse paged attention + block_size = value_cache.size(-1) + assert (blocksparse_block_size > 0 and + blocksparse_block_size % block_size == 0), \ + (f"{blocksparse_block_size=} needs to be a multiple of" + f"{block_size=} used in block_tables.") + + output = torch.empty_like(query) + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + # TODO(woosuk): Tune this heuristic. + # For context len > 8192, use V2 kernel to avoid shared memory shortage. + + if use_tc and head_size==128: + if envs.VLLM_USE_PA_PRINT_PARAM: + print("PA V1 SIZE:") + print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}") + print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}") + if attn_masks is None: + ops.paged_attention_v1_opt_tc( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step + ) + else: + ops.paged_attention_v1_opt_tc_with_mask( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + attn_masks, + attn_masks_stride + ) + return output + + use_v1 = (max_seq_len <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)) + + if use_v1: + # Run PagedAttention V1. + if envs.VLLM_USE_PA_PRINT_PARAM: + print("PA V1 SIZE:") + print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}") + print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}") + + if envs.VLLM_USE_OPT_OP: + if attn_masks is None: + ops.paged_attention_v1_opt( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step + ) + else: + ops.paged_attention_v1_opt_with_mask( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + attn_masks, + attn_masks_stride + ) + else: + if attn_masks is None: + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step + ) + else: + ops.paged_attention_v1_with_mask( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + attn_masks, + attn_masks_stride + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + + if envs.VLLM_USE_PA_PRINT_PARAM: + print("PA V2 SIZE:") + print(f"exp_sums.shape = {exp_sums.shape}, max_logits.shape = {max_logits.shape}, tmp_output.shape = {tmp_output.shape}") + print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}") + print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}") + + if envs.VLLM_USE_OPT_OP: + if attn_masks is None: + ops.paged_attention_v2_opt( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step + ) + else: + ops.paged_attention_v2_opt_with_mask( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + attn_masks, + attn_masks_stride + ) + else: + if attn_masks is None: + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step + ) + else: + ops.paged_attention_v2_with_mask( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + attn_masks, + attn_masks_stride + ) + return output + + @staticmethod + def forward_prefix( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache_dtype: str, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + query_start_loc: torch.Tensor, + seq_lens_tensor: torch.Tensor, + max_query_len: int, + alibi_slopes: Optional[torch.Tensor], + sliding_window: Optional[int], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + ) -> torch.Tensor: + output = torch.empty_like(query) + max_seq_len = None + context_attention_fwd( + query, + key, + value, + output, + kv_cache_dtype, + key_cache, + value_cache, + block_tables, + # query_start_loc is (batch_size + 1,) + query_start_loc, + seq_lens_tensor, + max_seq_len, + max_query_len, + k_scale, + v_scale, + alibi_slopes, + sliding_window, + ) + return output + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) \ No newline at end of file diff --git a/vllm/attention/ops/pallas_kv_cache_update.py b/vllm/attention/ops/pallas_kv_cache_update.py new file mode 100644 index 0000000..e7d727a --- /dev/null +++ b/vllm/attention/ops/pallas_kv_cache_update.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import functools + +import jax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu + +from vllm.utils import cdiv + + +def _kv_cache_update_kernel( + # Prefetch + slices_ref, # [3, padded_num_slices], list of (kv_cache_start, + # new_kv_start, slice_len) + # Input + new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim] + kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads, + # head_dim] + # Output + _, # [total_num_pages * page_size, num_combined_kv_heads, head_dim] + # Scratch + scratch, # [num_slices_per_block, page_size, num_combined_kv_heads, + # head_dim] + sem, +): + async_copies = [] + block_idx = pl.program_id(0) + num_slices_per_block = scratch.shape[0] + + # Copy from new_kv_hbm_ref to scratch + for i in range(num_slices_per_block): + offset_i = i + block_idx * num_slices_per_block + new_kv_start = slices_ref[1, offset_i] + length = slices_ref[2, offset_i] + async_copy = pltpu.make_async_copy( + new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...], + scratch.at[i, pl.ds(0, length), ...], + sem, + ) + async_copy.start() + async_copies.append(async_copy) + + for async_copy in async_copies: + async_copy.wait() + + # Copy from scratch to kv_cache_hbm_ref + async_copies.clear() + for i in range(num_slices_per_block): + offset_i = i + block_idx * num_slices_per_block + kv_cache_start = slices_ref[0, offset_i] + length = slices_ref[2, offset_i] + async_copy = pltpu.make_async_copy( + scratch.at[i, pl.ds(0, length), ...], + kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...], + sem, + ) + async_copy.start() + async_copies.append(async_copy) + for async_copy in async_copies: + async_copy.wait() + + +@functools.partial( + jax.jit, + static_argnames=["page_size", "num_slices_per_block"], +) +def kv_cache_update( + new_kv: jax.Array, # [total_num_token, num_combined_kv_heads, head_dim] + slices: jax. + Array, # [3, slices], list of (kv_cache_start, new_kv_start, slice_len) + kv_cache: jax. + Array, # [total_num_pages * page_size, num_combined_kv_heads, head_dim] + num_kv_update_slices: jax.Array, # [1] + *, + page_size: int = 32, + num_slices_per_block: int = 8, +): + assert slices.shape[1] % num_slices_per_block == 0 + _, num_combined_kv_heads, head_dim = new_kv.shape + assert kv_cache.shape[1] == num_combined_kv_heads + assert kv_cache.shape[2] == head_dim + assert head_dim % 128 == 0 + # TODO: Add dynamic check to make sure that the all the slice lengths are + # smaller or equal to page_size + + in_specs = [ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ] + + out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)] + out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)] + + scalar_prefetches = [slices] + scratch = pltpu.VMEM( + (num_slices_per_block, page_size, num_combined_kv_heads, head_dim), + new_kv.dtype, + ) + + scratch_shapes = [ + scratch, + pltpu.SemaphoreType.DMA, + ] + + kernel = pl.pallas_call( + _kv_cache_update_kernel, + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=len(scalar_prefetches), + in_specs=in_specs, + out_specs=out_specs, + grid=(cdiv(num_kv_update_slices[0], num_slices_per_block), ), + scratch_shapes=scratch_shapes, + ), + out_shape=out_shape, + input_output_aliases={len(scalar_prefetches) + 1: 0}, + ) + + return kernel(*scalar_prefetches, new_kv, kv_cache)[0] diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py new file mode 100644 index 0000000..0745b28 --- /dev/null +++ b/vllm/attention/ops/prefix_prefill.py @@ -0,0 +1,906 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# The kernels in this file are adapted from LightLLM's context_attention_fwd: +# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py + +import torch + +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton + +# Static kernels parameters +# BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64 +# NUM_WARPS = 4 if current_platform.is_rocm() else 8 + +BASE_BLOCK = 32 if current_platform.has_device_capability(80) else 32 +NUM_WARPS = 8 + + +# To check compatibility +IS_TURING = current_platform.get_device_capability() == (7, 5) + + +# Here's an example autotuner config for this kernel. This config does provide +# a performance improvement, but dramatically increases first call latency in +# triton 3.2. Because of this tradeoff, it's currently commented out. +# @triton.autotune( +# configs=[ +# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, \ +# "num_unroll_cache": 4, \ +# "num_unroll_request": 1 } | \ +# ({"kpack": 2, "waves_per_eu": 2} \ +# if current_platform.is_rocm() else {}), \ +# num_warps=4, \ +# num_stages=1) +# ], +# key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"] +# ) +@triton.jit +def _fwd_kernel(Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + x: tl.constexpr, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl: tl.constexpr, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: tl.constexpr, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL_PADDED: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, + num_unroll_cache: tl.constexpr, + num_unroll_request: tl.constexpr, + SKIP_DECODE: tl.constexpr, + MAX_Q_LEN: tl.constexpr = 0, + MAX_CTX_LEN: tl.constexpr = 0): + + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + + if SKIP_DECODE and cur_batch_query_len == 1: + return + + # start position inside of the query + # generally, N goes over kv, while M goes over query_len + block_start_loc = BLOCK_M * start_m + + # initialize offsets + # [BLOCK_SIZE]; starts at 0 + offs_bs_n = tl.arange(0, BLOCK_SIZE) + # [N]; starts at 0 + offs_n = tl.arange(0, BLOCK_N) + # [D]; starts at 0 + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + # [M]; starts at current position in query + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # [M,D] + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, + 0).to(tl.int1) # [D] + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len), + other=0.0) # [M,D] + + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D] + + # compute query against context (no causal mask here) + for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \ + loop_unroll_factor=num_unroll_cache): + start_n = tl.multiple_of(start_n, BLOCK_SIZE) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + (start_n // BLOCK_SIZE) * stride_b_loc_s) + # [D,BLOCK_SIZE] + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + + # [BLOCK_SIZE,D] + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + offs_bs_n[:, None] * stride_v_cache_bl) + + if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ + BLOCK_DMODEL != BLOCK_DMODEL_PADDED: + k_load = tl.load( + K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + else: + k_load = tl.load(K_cache + off_k) + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N] + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + if SLIDING_WINDOW > 0: + # (cur_batch_ctx_len + offs_m[:, None]) are the positions of + # Q entries in sequence + # (start_n + offs_bs_n[None, :]) are the positions of + # KV entries in sequence + # So the condition makes sure each entry in Q only attends + # to KV entries not more than SLIDING_WINDOW away. + # + # We can't use -inf here, because the + # sliding window may lead to the entire row being masked. + # This then makes m_ij contain -inf, which causes NaNs in + # exp(). + qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - + (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk, + -10000) + + # compute running maximum + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + alpha = tl.exp(m_i - m_ij) + acc = acc * alpha[:, None] + + # update acc + if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ + BLOCK_DMODEL != BLOCK_DMODEL_PADDED: + v_load = tl.load( + V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len), + other=0.0) # [N,D] + else: + v_load = tl.load(V_cache + off_v) + + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + # # update m_i and l_i + l_i = l_i * alpha + l_ij + m_i = m_ij + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + # block_mask is 0 when we're already past the current query length + block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) + + # compute query against itself (with causal mask) + for start_n in tl.range(0, \ + block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \ + loop_unroll_factor=num_unroll_request): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_query_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk *= sm_scale + # apply causal mask + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + if SLIDING_WINDOW > 0: + qk = tl.where( + offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW, + qk, -10000) + + # compute running maximum + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + alpha = tl.exp(m_i - m_ij) + acc = acc * alpha[:, None] + + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_query_len), + other=0.0) + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) + # update m_i and l_i + l_i = l_i * alpha + l_ij + m_i = m_ij + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len)) + return + + +@triton.jit +def _fwd_kernel_flash_attn_v2( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load(Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) + < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) + < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + # acc /= l_i[:, None] + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + +@triton.jit +def _fwd_kernel_alibi( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + Alibi_slopes, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_N: tl.constexpr, + SKIP_DECODE: tl.constexpr, +): + # attn_bias[] + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + # cur_batch_seq_len: the length of prompts + # cur_batch_ctx_len: the length of prefix + # cur_batch_in_all_start_index: the start id of the dim=0 + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len + + if SKIP_DECODE and cur_batch_query_len == 1: + return + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) + + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = 0 + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = ( + bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = (bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k_load = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, + float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v_load = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision='ieee') + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + # init alibi + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange(0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = cur_batch_ctx_len + # # init debugger + # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc + # offset_db_k = tl.arange(0, BLOCK_N) + # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & ((start_n + offs_n[None, :]) + < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, k, acc=qk, input_precision='ieee') + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, + float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) + < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + p = p.to(v.dtype) + + acc = tl.dot(p, v, acc=acc, input_precision='ieee') + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) + return + + +@torch.inference_mode() +def context_attention_fwd(q, + k, + v, + o, + kv_cache_dtype: str, + k_cache, + v_cache, + b_loc, + b_start_loc, + b_seq_len, + max_seq_len, + max_input_len, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + alibi_slopes=None, + sliding_window=None, + sm_scale=None, + skip_decode=False): + + q_dtype_is_f32 = q.dtype is torch.float32 + + # Turing does have tensor core for float32 multiplication + # use ieee as fallback for triton kernels work. There is also + # warning on vllm/config.py to inform users this fallback + # implementation + IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None + + # Conversion of FP8 Tensor from uint8 storage to + # appropriate torch.dtype for interpretation by Triton + if "fp8" in kv_cache_dtype: + assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()] + assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()] + + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + target_dtype = current_platform.fp8_dtype() + elif kv_cache_dtype == "fp8_e5m2": + target_dtype = torch.float8_e5m2 + else: + raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) + + k_cache = k_cache.view(target_dtype) + v_cache = v_cache.view(target_dtype) + + if (k_cache.dtype == torch.uint8 + or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"): + raise ValueError("kv_cache_dtype='auto' unsupported for\ + FP8 KV Cache prefill kernel") + + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + # round up Lk to a power of 2 - this is required for Triton block size + Lk_padded = triton.next_power_of_2(Lk) + + if sm_scale is None: + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + num_queries_per_kv = q.shape[1] // k.shape[1] + + assert batch + 1 == len(b_start_loc) + + # 0 means "disable" + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + + if alibi_slopes is not None: + # need to reduce num. blocks when using fp32 + # due to increased use of GPU shared memory + # if q.dtype is torch.float32: + BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK + # batch, head, + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + _fwd_kernel_alibi[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + k_scale, + v_scale, + b_start_loc, + b_seq_len, + alibi_slopes, + v_cache.shape[3], + k_cache.shape[4], + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride( + 3), #[num_blocks, num_kv_heads, head_size, block_size] + num_queries_per_kv=num_queries_per_kv, + IN_PRECISION=IN_PRECISION, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, + BLOCK_N=BLOCK, + SKIP_DECODE=skip_decode, + num_warps=NUM_WARPS, + num_stages=1, + ) + return + + max_seq_len = 0 if max_seq_len is None else max_seq_len + extra_kargs = {} + if current_platform.is_rocm(): + extra_kargs = {"kpack": 2, "waves_per_eu": 2} + + grid = lambda META: (batch, head, + triton.cdiv(max_input_len, META["BLOCK_M"])) + _fwd_kernel[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + k_scale, + v_scale, + b_start_loc, + b_seq_len, + k_cache.shape[4], + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride(3), #[num_blocks, num_kv_heads, head_size, block_size] + BLOCK_SIZE=v_cache.shape[3], + num_queries_per_kv=num_queries_per_kv, + IN_PRECISION=IN_PRECISION, + BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, + SLIDING_WINDOW=sliding_window, + SKIP_DECODE=skip_decode, + BLOCK_M=128, + BLOCK_N=64, + num_unroll_cache=4, + num_unroll_request=1, + num_warps=4, + num_stages=1, + **extra_kargs) + return diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py new file mode 100644 index 0000000..cce6b46 --- /dev/null +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op + + +def get_aiter_mla_metadata(max_batch_size: int, block_size: int, + max_block_per_batch: int, + device: torch.device) -> tuple[torch.Tensor, ...]: + paged_kv_indices = torch.zeros(max_batch_size * max_block_per_batch, + dtype=torch.int32, + device=device) + paged_kv_indptr = torch.zeros(max_batch_size + 1, + dtype=torch.int32, + device=device) + paged_kv_last_page_lens = torch.full((max_batch_size, ), + block_size, + dtype=torch.int32) + qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device) + return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr + + +def aiter_mla_decode_fwd( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + sm_scale: float, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + kv_last_page_lens: Optional[torch.Tensor] = None, + logit_cap: float = 0.0, +): + + torch.ops.vllm.rocm_aiter_mla_decode_fwd(q, + kv_buffer.view( + -1, 1, 1, q.shape[-1]), + o, + qo_indptr, + max_seqlen_qo, + kv_indptr, + kv_indices, + kv_last_page_lens, + sm_scale=sm_scale, + logit_cap=logit_cap) + + +def mla_decode_fwd_impl( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + kv_last_page_lens: Optional[torch.Tensor] = None, + sm_scale: float = 1.0, + logit_cap: float = 0.0, +) -> None: + from aiter.mla import mla_decode_fwd + + mla_decode_fwd(q, + kv_buffer.view(-1, 1, 1, q.shape[-1]), + o, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + max_seqlen_qo, + sm_scale=sm_scale, + logit_cap=logit_cap) + + +def mla_decode_fwd_fake( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + qo_indptr: torch.Tensor, + max_seqlen_qo: int, + kv_indptr: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + kv_last_page_lens: Optional[torch.Tensor] = None, + sm_scale: float = 1.0, + logit_cap: float = 0.0, +) -> None: + pass + + +if current_platform.is_rocm(): + direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd", + op_func=mla_decode_fwd_impl, + mutates_args=["o"], + fake_impl=mla_decode_fwd_fake, + tags=[torch.Tag.needs_fixed_stride_order]) diff --git a/vllm/attention/ops/rocm_aiter_paged_attn.py b/vllm/attention/ops/rocm_aiter_paged_attn.py new file mode 100644 index 0000000..ad97152 --- /dev/null +++ b/vllm/attention/ops/rocm_aiter_paged_attn.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import aiter as rocm_aiter +import torch + +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.platforms import current_platform +from vllm.utils import cdiv + +FP8_DTYPE = current_platform.fp8_dtype() + + +class AITERPagedAttention(PagedAttention): + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + ) -> None: + if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]: + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype, k_scale, + v_scale) + else: + kv_cache_torch_dtype = (FP8_DTYPE + if "fp8" in kv_cache_dtype else torch.int8) + key_cache = key_cache.view(kv_cache_torch_dtype) + value_cache = value_cache.view(kv_cache_torch_dtype) + + rocm_aiter.reshape_and_cache_with_pertoken_quant( + key, value, key_cache, value_cache, k_scale, v_scale, + slot_mapping.flatten(), True) + + @staticmethod + def forward_decode( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> torch.Tensor: + if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]: + return PagedAttention.forward_decode( + query=query, + key_cache=key_cache, + value_cache=value_cache, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + kv_cache_dtype=kv_cache_dtype, + num_kv_heads=num_kv_heads, + scale=scale, + alibi_slopes=alibi_slopes, + k_scale=k_scale, + v_scale=v_scale, + tp_rank=tp_rank, + blocksparse_local_blocks=blocksparse_local_blocks, + blocksparse_vert_stride=blocksparse_vert_stride, + blocksparse_block_size=blocksparse_block_size, + blocksparse_head_sliding_step=blocksparse_head_sliding_step) + + if "fp8" in kv_cache_dtype: + key_cache = key_cache.view(torch.float8_e4m3fnuz) + value_cache = value_cache.view(torch.float8_e4m3fnuz) + + if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: + # use blocksparse paged attention + block_size = value_cache.size(-1) + assert (blocksparse_block_size > 0 and + blocksparse_block_size % block_size == 0), \ + (f"{blocksparse_block_size=} needs to be a multiple of" + f"{block_size=} used in block_tables.") + + output = torch.empty_like(query) + block_size = value_cache.shape[3] + max_num_blocks_per_seq = cdiv(max_seq_len, block_size) + + rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables, + seq_lens, max_num_blocks_per_seq, k_scale, + v_scale, output) + return output diff --git a/vllm/attention/ops/triton_decode_attention.py b/vllm/attention/ops/triton_decode_attention.py new file mode 100644 index 0000000..2b66870 --- /dev/null +++ b/vllm/attention/ops/triton_decode_attention.py @@ -0,0 +1,1614 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +# which was originally adapted from +# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py +# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py + +# Changes: +# - Add support for page size >= 1. + +# Copyright 2025 vLLM Team +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Memory-efficient attention for decoding. +It supports page size >= 1. +""" + +import os +import logging + +import torch + +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton +from vllm import envs + +is_hip_ = current_platform.is_rocm() +os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"0" + +logger = logging.getLogger(__name__) + +# Only print the following warnings when triton version < 3.2.0. +# The issue won't affect performance or accuracy. +if triton.__version__ < '3.2.0': + logger.warning( + "The following error message 'operation scheduled before its operands' " + "can be ignored.") + + +@triton.jit +def tanh(x): + # Tanh is just a scaled sigmoid + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def _fwd_kernel_stage1( + Q, + K_Buffer, + V_Buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + Att_Out, + stride_req_to_tokens_b, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + split_kv_id = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = cur_batch + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + q = tl.load(Q + off_q, mask=mask_d, other=0.0) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, + cur_batch_seq_len) + + e_max = -float("inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_buf_k = (kv_loc[:, None] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + offs_d[None, :]) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]), + other=0.0, + ) + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(offs_n < split_kv_end, qk, float("-inf")) + + offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + offs_dv[None, :]) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) + + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max + + offs_mid_o = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + offs_dv) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum, + mask=(mask_dv), + ) + + offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + Lv) + + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + ) + + +def _decode_att_m_fwd( + q, + k_buffer, + v_buffer, + att_out, + Req_to_tokens, + B_Seqlen, + num_kv_splits, + sm_scale, + page_size, + logit_cap, +): + BLOCK = 64 if not is_hip_ else 8 + + NUM_KV_SPLITS = num_kv_splits + Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] + + batch, head_num = q.shape[0], q.shape[1] + + grid = (batch, head_num, NUM_KV_SPLITS) + kv_group_num = q.shape[1] // k_buffer.shape[-2] + + num_warps = 4 + if kv_group_num != 1: + num_warps = 1 if is_hip_ else 2 + + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DV = triton.next_power_of_2(Lv) + + _fwd_kernel_stage1[grid]( + q, + k_buffer, + v_buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + att_out, + Req_to_tokens.stride(0), + q.stride(0), + q.stride(1), + k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK, + NUM_KV_SPLITS=NUM_KV_SPLITS, + PAGE_SIZE=page_size, + logit_cap=logit_cap, + num_warps=num_warps, + Lk=Lk, + Lv=Lv, + ) + + +@triton.jit +def _fwd_grouped_kernel_stage1( + Q, + K_Buffer, + V_Buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + Att_Out, + stride_req_to_tokens_b, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + q_head_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_H: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head_id = tl.program_id(1) + cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) + split_kv_id = tl.program_id(2) + + if kv_group_num > BLOCK_H: + VALID_BLOCK_H: tl.constexpr = BLOCK_H + else: + VALID_BLOCK_H: tl.constexpr = kv_group_num + cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) + mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H + mask_h = mask_h & (cur_head < q_head_num) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = cur_batch + + offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[ + None, :] + q = tl.load(Q + offs_q, + mask=(mask_h[:, None]) & (mask_d[None, :]), + other=0.0) + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + mask_dpe = offs_dpe < Lk + off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + + offs_dpe[None, :]) + qpe = tl.load(Q + off_qpe, + mask=(mask_h[:, None]) & (mask_dpe[None, :]), + other=0.0) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, + cur_batch_seq_len) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_buf_k = (kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + offs_d[:, None]) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), + other=0.0, + ) + qk = tl.dot(q, k.to(q.dtype)) + if BLOCK_DPE > 0: + offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None]) + kpe = tl.load( + K_Buffer + offs_buf_kpe, + mask=(offs_n[None, :] < split_kv_end) & + (mask_dpe[:, None]), + other=0.0, + ) + qk += tl.dot(qpe, kpe.to(qpe.dtype)) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), + qk, float("-inf")) + + offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + offs_dv[None, :]) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v.dtype), v) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + offs_mid_o = (cur_batch * stride_mid_ob + + cur_head[:, None] * stride_mid_oh + + split_kv_id * stride_mid_os + offs_dv[None, :]) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum[:, None], + mask=(mask_h[:, None]) & (mask_dv[None, :]), + ) + + offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + Lv) + + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + mask=mask_h, + ) + + +def _decode_grouped_att_m_fwd( + q, + k_buffer, + v_buffer, + att_out, + Req_to_tokens, + B_Seqlen, + num_kv_splits, + sm_scale, + page_size, + logit_cap, +): + BLOCK = 32 + Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] + + # [TODO] work around shmem limit on MI3xx + if is_hip_ and Lk >= 576: + BLOCK = 16 + + if Lk == 576: + BLOCK_DMODEL = 512 + BLOCK_DPE = 64 + elif Lk == 288: + BLOCK_DMODEL = 256 + BLOCK_DPE = 32 + else: + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DPE = 0 + BLOCK_DV = triton.next_power_of_2(Lv) + + batch, head_num = q.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k_buffer.shape[-2] + + BLOCK_H = 16 + NUM_KV_SPLITS = num_kv_splits + grid = ( + batch, + triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), + NUM_KV_SPLITS, + ) + + extra_kargs = {} + num_stages = 2 + if is_hip_: + # https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization + # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py + extra_kargs = { + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } + num_stages = 1 + + _fwd_grouped_kernel_stage1[grid]( + q, + k_buffer, + v_buffer, + sm_scale, + Req_to_tokens, + B_Seqlen, + att_out, + Req_to_tokens.stride(0), + q.stride(0), + q.stride(1), + k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM) + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + q_head_num=head_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DPE=BLOCK_DPE, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK, + BLOCK_H=BLOCK_H, + NUM_KV_SPLITS=NUM_KV_SPLITS, + PAGE_SIZE=page_size, + logit_cap=logit_cap, + num_warps=4, + num_stages=num_stages, + Lk=Lk, + Lv=Lv, + **extra_kargs, + ) + + +@triton.jit +def _fwd_kernel_stage2( + Mid_O, + o, + B_Seqlen, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + NUM_KV_SPLITS: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv + + for split_kv_id in range(0, NUM_KV_SPLITS): + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, + cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, + mask=mask_d, + other=0.0) + tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + o + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + + +def _decode_softmax_reducev_fwd( + logits, + q, + o, + v_buffer, + b_seq_len, + num_kv_splits, +): + batch, head_num = q.shape[0], q.shape[1] + Lv = v_buffer.shape[-1] + BLOCK_DV = triton.next_power_of_2(Lv) + + NUM_KV_SPLITS = num_kv_splits + + extra_kargs = {} + if is_hip_: + # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html + # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py + extra_kargs = { + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } + + grid = (batch, head_num) + _fwd_kernel_stage2[grid]( + logits, + o, + b_seq_len, + logits.stride(0), + logits.stride(1), + logits.stride(2), + o.stride(0), + o.stride(1), + NUM_KV_SPLITS=NUM_KV_SPLITS, + BLOCK_DV=BLOCK_DV, + Lv=Lv, + num_warps=4, + **extra_kargs, + ) + + +def decode_attention_fwd_normal( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap=0.0, +): + _decode_att_m_fwd( + q, + k_buffer, + v_buffer, + attn_logits, + req_to_token, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, + num_kv_splits) + + +def decode_attention_fwd_grouped( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap=0.0, +): + _decode_grouped_att_m_fwd( + q, + k_buffer, + v_buffer, + attn_logits, + req_to_token, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, + num_kv_splits) + + + +# opt +# @triton.autotune( +# configs=[ +# triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1), +# ], +# key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh"] +# ) +@triton.jit +def _decode_v1_kernel_stage1_use_tc( + Q, + K_Buffer, + sm_scale, + Req_to_tokens, + #B_req_idx, + B_Start_Loc, + B_Seqlen, + Att_Out, + stride_req_to_tokens_b, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + att_stride_h, + kv_group_num: tl.constexpr, + q_head_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_H: tl.constexpr, + SPLIT_K: tl.constexpr, + PAGE_SIZE: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head_id = tl.program_id(1) + cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) + split_k_id = tl.program_id(2) + + # reduce_dtype = Att_Out.dtype.element_ty + + if BLOCK_H < kv_group_num: + VALID_BLOCK_H: tl.constexpr = BLOCK_H + else: + VALID_BLOCK_H: tl.constexpr = kv_group_num + cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) + mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H + mask_h = mask_h & (cur_head < q_head_num) + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + # cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + cur_batch_req_idx = cur_batch + + offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] + q = tl.load( + Q + offs_q, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk), other=0.0 + ) # .to(reduce_dtype) + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + off_qpe = ( + cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] + ) + qpe = tl.load(Q + off_qpe, mask=mask_h[:, None], other=0.0) # .to(reduce_dtype) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K) + split_k_start = kv_len_per_split * split_k_id + split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len) + + for start_n in range(split_k_start, split_k_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + + offs_n // PAGE_SIZE, + mask=offs_n < split_k_end, + other=0, + ) + k_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_buf_k = ( + k_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[:, None] + ) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[None, :] < split_k_end) & (offs_d[:, None] < Lk), + other=0.0, + ) # .to(reduce_dtype) + qk = tl.dot(q, k) + if BLOCK_DPE > 0: + offs_buf_kpe = ( + k_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None] + ) + kpe = tl.load( + K_Buffer + offs_buf_kpe, + mask=offs_n[None, :] < split_k_end, + other=0.0, + ) # .to(reduce_dtype) + qk += tl.dot(qpe, kpe) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + offs_o = cur_head[:, None] * att_stride_h + ( + cur_batch_in_all_start_index + offs_n[None, :] + ) + + tl.store( + Att_Out + offs_o, + qk, + mask=mask_h[:, None] & (offs_n[None, :] < split_k_end), + ) + +# @triton.autotune( +# configs=[ +# triton.Config({"BLOCK_N": 8}, num_warps=1, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 8}, num_warps=2, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 8}, num_warps=4, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 8}, num_warps=8, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 16}, num_warps=1, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 16}, num_warps=2, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 16}, num_warps=4, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 16}, num_warps=8, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 32}, num_warps=1, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 32}, num_warps=2, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 32}, num_warps=4, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 32}, num_warps=8, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 64}, num_warps=1, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 64}, num_warps=2, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 64}, num_warps=4, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 64}, num_warps=8, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 128}, num_warps=1, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 128}, num_warps=2, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 128}, num_warps=4, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 128}, num_warps=8, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 256}, num_warps=1, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 256}, num_warps=2, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 256}, num_warps=4, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 256}, num_warps=8, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 512}, num_warps=1, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 512}, num_warps=2, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 512}, num_warps=4, num_ldmatrixes=0, num_stages=1), +# triton.Config({"BLOCK_N": 512}, num_warps=8, num_ldmatrixes=0, num_stages=1), +# ], +# key=["B_Seqlen","stride_logic_h","stride_buf_vbs","stride_buf_vh"] +# ) +@triton.jit +def _decode_v1_kernel_stage2_use_tc( + logits, + V_Buffer, + Out, + Req_to_tokens, + #B_req_idx, + B_Start_Loc, + B_Seqlen, + stride_logic_h, + stride_buf_vbs, + stride_buf_vh, + stride_obs, + stride_oh, + stride_req_to_token_b, + kv_group_num: tl.constexpr, + q_head_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_H: tl.constexpr, + PAGE_SIZE: tl.constexpr, + Lv: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_kv_head = tl.program_id(1) + cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H) + mask_h = cur_head < (cur_kv_head + 1) * kv_group_num + mask_h = mask_h & (cur_head < q_head_num) + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch) + cur_batch_req_idx = cur_batch #tl.load(B_req_idx + cur_batch) + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + + offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :] + v_ptrs = V_Buffer + offs_buf_v + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + v_page_number = tl.load( + Req_to_tokens + + cur_batch_req_idx * stride_req_to_token_b + + (start_n + offs_n) // PAGE_SIZE, + mask=(start_n + offs_n) < cur_batch_seq_len, + other=0, + ) + v_loc = v_page_number * PAGE_SIZE + (start_n + offs_n) % PAGE_SIZE + offs_qk = cur_head[:, None] * stride_logic_h + ( + cur_batch_start_loc + start_n + offs_n[None, :] + ) + + qk = tl.load( + logits + offs_qk, + mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len), + other=float("-inf"), + ) #[head, block_n] + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + old_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + e_sum = e_sum * old_scale + tl.sum(p, 1) + v = tl.load( + v_ptrs + v_loc[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv) + ) #[block_n,head_dim] + p = p.to(v.dtype) + acc = acc * old_scale[:, None] + tl.dot(p, v) + e_max = n_e_max + + acc = acc / e_sum[:, None] + off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :] + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=(mask_h[:, None]) & (offs_d[None, :] < Lv)) + +def _decode_v1_stage1_use_tc( + q, + k_buffer, + att_out, + Req_to_tokens, + #B_req_idx, + B_Start_Loc, + B_Seqlen, + sm_scale, + page_size, + num_kv_splits, + best_config, + logit_cap, +): + Lk = k_buffer.shape[-1] + + if Lk == 576: + BLOCK_DMODEL = 512 + BLOCK_DPE = 64 + elif Lk == 288: + BLOCK_DMODEL = 256 + BLOCK_DPE = 32 + else: + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DPE = 0 + + # batch, head_num = B_req_idx.shape[0], q.shape[1] + batch, head_num = q.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k_buffer.shape[-2] + + BLOCK_N = best_config['stage1']['BLOCK_N'] + SPLIT_K = num_kv_splits # best_config.SPLIT_K + num_stages = best_config['stage1']['num_stages'] + num_warps = best_config['stage1']['num_warps'] + + BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num))) + grid = lambda META: ( + batch, + triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), + SPLIT_K, + ) + _decode_v1_kernel_stage1_use_tc[grid]( + q, + k_buffer, + sm_scale, + Req_to_tokens, + #B_req_idx, + B_Start_Loc, + B_Seqlen, + att_out, + Req_to_tokens.stride(0), + q.stride(0), + q.stride(1), + k_buffer.stride(-3), + k_buffer.stride(-2), + att_out.stride(0), + kv_group_num=kv_group_num, + q_head_num=head_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DPE=BLOCK_DPE, + BLOCK_N=BLOCK_N, + BLOCK_H=BLOCK_H, + SPLIT_K=SPLIT_K, + PAGE_SIZE=page_size, + logit_cap=logit_cap, + num_warps=num_warps, + num_stages=num_stages, + Lk=Lk, + kpack=2, + ) + + # return _decode_v1_kernel_stage1_use_tc.best_config + + +def _decode_v1_stage2_use_tc( + logits, + v_buffer, + o, + req_to_tokens, + #b_req_idx, + b_start_loc, + b_seq_len, + best_config, + page_size, +): + batch, head_num = b_seq_len.shape[0], logits.shape[0] + kv_group_num = logits.shape[0] // v_buffer.shape[-2] + + BLOCK_N = best_config['stage2']['BLOCK_N'] + num_stages = best_config['stage2']['num_stages'] + num_warps = best_config['stage2']['num_warps'] + BLOCK_H = max(16, triton.next_power_of_2(kv_group_num)) + grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1) + + Lv = v_buffer.shape[-1] + BLOCK_DMODEL = triton.next_power_of_2(Lv) + + _decode_v1_kernel_stage2_use_tc[grid]( + logits, + v_buffer, + o, + req_to_tokens, + #b_req_idx, + b_start_loc, + b_seq_len, + logits.stride(0), + v_buffer.stride(-3), + v_buffer.stride(-2), + o.stride(0), + o.stride(1), + req_to_tokens.stride(0), + kv_group_num=kv_group_num, + q_head_num=head_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_N=BLOCK_N, + BLOCK_H=BLOCK_H, + PAGE_SIZE=page_size, + Lv=Lv, + num_warps=num_warps, + num_stages=num_stages, + ) + + # return _decode_v1_kernel_stage2_use_tc.best_config + + +def decode_attention_v1( + q, + k_buffer, + v_buffer, + o, + req_to_token, + #b_req_idx, + b_start_loc, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + best_config, + page_size, + logit_cap=0.0, +): + # GQA/MQA/MLA + # _decode_v1_stage1_best_config = _decode_v1_stage1_use_tc( + # q, + # k_buffer, + # attn_logits, + # req_to_token, + # #b_req_idx, + # b_start_loc, + # b_seq_len, + # sm_scale, + # page_size, + # num_kv_splits, + # logit_cap, + # ) + # _decode_v1_stage2_best_config = _decode_v1_stage2_use_tc( + # attn_logits, + # v_buffer, + # o, + # req_to_token, + # #b_req_idx, + # b_start_loc, + # b_seq_len, + # page_size, + # ) + # return _decode_v1_stage1_best_config, _decode_v1_stage2_best_config + _decode_v1_stage1_use_tc( + q, + k_buffer, + attn_logits, + req_to_token, + #b_req_idx, + b_start_loc, + b_seq_len, + sm_scale, + page_size, + num_kv_splits, + best_config, + logit_cap, + ) + _decode_v1_stage2_use_tc( + attn_logits, + v_buffer, + o, + req_to_token, + #b_req_idx, + b_start_loc, + b_seq_len, + best_config, + page_size, + ) + + +# @triton.autotune( +# configs=[ +# triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=2, num_stages=1), +# triton.Config({"BLOCK_N": 16, "BLOCK_DIM":64}, num_warps=4, num_stages=1), +# triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=2, num_stages=1), +# triton.Config({"BLOCK_N": 32, "BLOCK_DIM":64}, num_warps=4, num_stages=1), +# triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=2, num_stages=1), +# triton.Config({"BLOCK_N": 64, "BLOCK_DIM":32}, num_warps=4, num_stages=1), +# triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=2, num_stages=1), +# triton.Config({"BLOCK_N": 128, "BLOCK_DIM":32}, num_warps=4, num_stages=1), +# triton.Config({"BLOCK_N": 256, "BLOCK_DIM":32}, num_warps=2, num_stages=1), +# triton.Config({"BLOCK_N": 256, "BLOCK_DIM":32}, num_warps=4, num_stages=1), +# ], +# key=["B_Seqlen","stride_qbs","stride_buf_kbs","stride_buf_kh", "stride_buf_vbs", "stride_buf_vh"] +# ) +@triton.jit +def _decode_v2_kernel_stage1_use_tc( + Q, + K_Buffer, + V_Buffer, + sm_scale, + Req_to_tokens, + # B_req_idx, + B_Seqlen, + Att_Out, + stride_req_to_tokens_b, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + q_head_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DIM: tl.constexpr, + BLOCK_H: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head_id = tl.program_id(1) + cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) + split_kv_id = tl.program_id(2) + + if BLOCK_H < kv_group_num: + VALID_BLOCK_H: tl.constexpr = BLOCK_H + else: + VALID_BLOCK_H: tl.constexpr = kv_group_num + cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) + mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H + mask_h = mask_h & (cur_head < q_head_num) + + # offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + # mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + # cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + cur_batch_req_idx = cur_batch + + # offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] + # q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + mask_dpe = offs_dpe < Lk + off_qpe = ( + cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] + ) + qpe = tl.load( + Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0 + ) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) + + NUM_DIM_SPLIT = tl.cdiv(BLOCK_DMODEL, BLOCK_DIM) + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + + qk = tl.zeros([BLOCK_H, BLOCK_N], dtype=tl.float32) + for i in range(0, NUM_DIM_SPLIT): + offs_d = tl.arange(0, BLOCK_DIM) + i * BLOCK_DIM + mask_d = offs_d < Lk + offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None,:] + q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) + offs_buf_k = kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh + offs_d[:, None] + k = tl.load(K_Buffer + offs_buf_k, mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0) + qk += tl.dot(q, k.to(q.dtype)) + if BLOCK_DPE > 0: + offs_buf_kpe = ( + kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None] + ) + kpe = tl.load( + K_Buffer + offs_buf_kpe, + mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), + other=0.0, + ) + qk += tl.dot(qpe, kpe.to(qpe.dtype)) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where( + mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf") + ) + + offs_buf_v = ( + kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v.dtype), v) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head[:, None] * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv[None, :] + ) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum[:, None], + mask=(mask_h[:, None]) & (mask_dv[None, :]), + ) + + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + Lv + ) + + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + mask=mask_h, + ) + + +def _decode_v2_stage1_use_tc( + q, + k_buffer, + v_buffer, + att_out, + Req_to_tokens, + # B_req_idx, + B_Seqlen, + num_kv_splits, + sm_scale, + best_config, + page_size, + logit_cap, +): + + BLOCK = best_config['stage1']['BLOCK_N'] + BLOCK_DIM = best_config['stage1']['BLOCK_DIM'] + num_stages = best_config['stage1']['num_stages'] + num_warps = best_config['stage1']['num_warps'] + + Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] + + if Lk == 576: + BLOCK_DMODEL = 512 + BLOCK_DPE = 64 + elif Lk == 288: + BLOCK_DMODEL = 256 + BLOCK_DPE = 32 + else: + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DPE = 0 + BLOCK_DV = triton.next_power_of_2(Lv) + + # batch, head_num = B_req_idx.shape[0], q.shape[1] + batch, head_num = q.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k_buffer.shape[-2] + BLOCK_H = 16 + NUM_KV_SPLITS = num_kv_splits + + grid = lambda META: ( + batch, + triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), + NUM_KV_SPLITS, + ) + + _decode_v2_kernel_stage1_use_tc[grid]( + q, + k_buffer, + v_buffer, + sm_scale, + Req_to_tokens, + # B_req_idx, + B_Seqlen, + att_out, + Req_to_tokens.stride(0), + q.stride(0), + q.stride(1), + k_buffer.stride(-3), + k_buffer.stride(-2), + v_buffer.stride(-3), + v_buffer.stride(-2), + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + q_head_num=head_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DPE=BLOCK_DPE, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK, + BLOCK_DIM=BLOCK_DIM, + BLOCK_H=BLOCK_H, + NUM_KV_SPLITS=NUM_KV_SPLITS, + PAGE_SIZE=page_size, + logit_cap=logit_cap, + num_warps=num_warps, + num_stages=num_stages, + Lk=Lk, + Lv=Lv, + kpack=2, + ) + + # return _decode_v2_kernel_stage1_use_tc.best_config + + +# @triton.autotune( +# configs=[ +# triton.Config({}, num_warps=1, num_stages=1), +# triton.Config({}, num_warps=1, num_stages=1), +# triton.Config({}, num_warps=2, num_stages=1), +# triton.Config({}, num_warps=4, num_stages=1), +# triton.Config({}, num_warps=8, num_stages=1), + +# ], +# key=["B_Seqlen", "stride_mid_ob", "stride_mid_oh", "stride_mid_os"] +# ) +@triton.jit +def _decode_v2_kernel_stage2( + Mid_O, + O, + B_Seqlen, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + NUM_KV_SPLITS: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv + + for split_kv_id in range(0, NUM_KV_SPLITS): + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load( + Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 + ) + tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + + +def _decode_v2_stage2_use_tc( + logits, + q, + o, + v_buffer, + b_seq_len, + num_kv_splits, + best_config, +): + num_stages = best_config['stage2']['num_stages'] + num_warps = best_config['stage2']['num_warps'] + + batch, head_num = q.shape[0], q.shape[1] + Lv = v_buffer.shape[-1] + BLOCK_DV = triton.next_power_of_2(Lv) + + NUM_KV_SPLITS = num_kv_splits + + grid = (batch, head_num, 1) + _decode_v2_kernel_stage2[grid]( + logits, + o, + b_seq_len, + logits.stride(0), + logits.stride(1), + logits.stride(2), + o.stride(0), + o.stride(1), + NUM_KV_SPLITS=NUM_KV_SPLITS, + BLOCK_DV=BLOCK_DV, + Lv=Lv, + num_warps=num_warps, + num_stages=num_stages, + ) + + # return _decode_v2_kernel_stage2.best_config + + +def decode_attention_v2( + q, + k_buffer, + v_buffer, + o, + req_to_token, + # b_req_idx, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + best_config, + page_size, + logit_cap=0.0, +): + # _decode_v2_stage1_best_config = _decode_v2_stage1_use_tc( + # q, + # k_buffer, + # v_buffer, + # attn_logits, + # req_to_token, + # # b_req_idx, + # b_seq_len, + # num_kv_splits, + # sm_scale, + # page_size, + # logit_cap, + # ) + # _decode_v2_stage2_best_config = _decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits) + # return _decode_v2_stage1_best_config, _decode_v2_stage2_best_config + _decode_v2_stage1_use_tc( + q, + k_buffer, + v_buffer, + attn_logits, + req_to_token, + # b_req_idx, + b_seq_len, + num_kv_splits, + sm_scale, + best_config, + page_size, + logit_cap, + ) + _decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits, best_config) + + +def decode_attention_fwd( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + best_config, + page_size=1, + logit_cap=0.0, +): + assert num_kv_splits == attn_logits.shape[2] + kv_group_num = q.shape[1] // v_buffer.shape[-2] + b_start_loc = torch.arange(0, req_to_token.shape[0]*req_to_token.shape[1], req_to_token.shape[0]*req_to_token.shape[1] // q.shape[0], device="cuda").to(torch.int32) + if kv_group_num == 1: + # MHA + decode_attention_fwd_normal( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) + else: + # GQA/MQA/MLA + if envs.VLLM_USE_TRITON_OPT_MLA: + ''' + decode_attention_v2( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) + attn_logits_v1 = torch.empty( + (q.shape[1],req_to_token.shape[0]*req_to_token.shape[1]*page_size), + dtype=torch.float32, + device="cuda") + decode_attention_v1( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_start_loc, + b_seq_len, + attn_logits_v1, + num_kv_splits, # sub + sm_scale, + page_size, + logit_cap, + )''' + num_b = min(kv_group_num, 16) + grid_num = (q.shape[1] + num_b - 1) // num_b * q.shape[0] + L = req_to_token.shape[1]*page_size + if grid_num * num_kv_splits < 128: + num_kv_splits = (127 + grid_num) // grid_num + attn_logits_v2 = torch.empty( + (q.shape[0], q.shape[1], num_kv_splits, v_buffer.shape[-1] + 1), + dtype=torch.float32, + device="cuda", + ) + + if best_config['kernel_kind'] == 'v1_2stages_tc': + attn_logits_v1 = torch.empty( + (q.shape[1],req_to_token.shape[0]*req_to_token.shape[1]*page_size), + dtype=torch.float32, + device="cuda") + decode_attention_v1( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_start_loc, + b_seq_len, + attn_logits_v1, + num_kv_splits, + sm_scale, + best_config=best_config['best_config'], + page_size=page_size, + logit_cap=logit_cap, + ) + elif best_config['kernel_kind'] == 'v2_tc': + decode_attention_v2( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits_v2, + num_kv_splits, + sm_scale, + best_config=best_config['best_config'], + page_size=page_size, + logit_cap=logit_cap, + ) + else: + print("Unknown mla kernel kind: ", best_config['kernel_kind']) + else: + decode_attention_fwd_grouped( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, + logit_cap, + ) \ No newline at end of file diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py new file mode 100644 index 0000000..49070e4 --- /dev/null +++ b/vllm/attention/ops/triton_flash_attention.py @@ -0,0 +1,984 @@ +#!/usr/bin/env python +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao +(https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team, AMD ML Frameworks Triton team + +Features supported: + +1) Fwd with causal masking +2) Any sequence lengths without padding (currently fwd kernel only) +3) Support for different sequence lengths for q and k +4) Nested tensor API currently does not support dropout or bias. + +Not currently supported: + +1) Non power of two head dims + +""" + +import torch + +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton + +# Avoid misleading ROCm warning. +if current_platform.is_rocm(): + from vllm.platforms.rocm import on_gfx1x +else: + on_gfx1x = lambda *args, **kwargs: False + +torch_dtype: tl.constexpr = torch.float16 + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + + +@triton.jit +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, + stride).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, + stride) + rng_keep = rng_output > dropout_p + return rng_keep + + +@triton.jit +def load_fn(block_ptr, first, second, pad): + if first and second: + tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) + elif first: + tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) + elif second: + tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) + else: + tensor = tl.load(block_ptr) + return tensor + + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + actual_seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr, + USE_FP8: tl.constexpr, + qk_scale, + p_descale, +): + # loop over k, v, and update accumulator + for start_n in range(block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + k = load_fn( + K_block_ptr, + PADDED_HEAD, + MASK_STEPS and (n_extra_tokens != 0), + "zero", + ) + if PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: # noqa: SIM102 + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps + # if not is_modulo_mn. last step might get wasted but that is okay. + # check if this masking works for that case. + if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], + actual_seqlen_k, + dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- + qk += tl.dot(q, k) + if USE_FP8: + qk *= qk_scale + if bias_ptr is not None: + bias = load_fn(bias_ptr, False, MASK_STEPS + and (n_extra_tokens != 0), "zero") + # While bias is added after multiplying qk with sm_scale, our + # optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. + qk += bias * 1.44269504089 + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + philox_offset = (batch_philox_offset + + start_m * BLOCK_M * actual_seqlen_k + start_n - + BLOCK_N) + keep = dropout_mask( + philox_seed, + philox_offset, + dropout_p, + BLOCK_M, + BLOCK_N, + actual_seqlen_k, + ) + if RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + tl.where(keep, p, + -p).to(encoded_softmax_block_ptr.type.element_ty), + ) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + p.to(encoded_softmax_block_ptr.type.element_ty), + ) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + + if USE_FP8: + p *= p_descale + + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, BLOCK_N)) + return acc, l_i, m_i + + +def get_cdna_autotune_configs(): + return [ + triton.Config( + { + 'BLOCK_M': 256, + 'BLOCK_N': 64, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=8), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 128, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 256, + 'BLOCK_N': 128, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=8), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 1, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 3, + 'PRE_LOAD_V': True + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 3, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 64, + 'BLOCK_N': 64, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=8), + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=8), + # TODO: This config fails with head_size not pow2 with data mismatches. + # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, + # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + + # Fails in AccelerateAMDMatmul (Triton) assert when using FP8: + # triton.Config( + # { + # "BLOCK_M": 16, + # "BLOCK_N": 16, + # "waves_per_eu": 1, + # "PRE_LOAD_V": False, + # }, + # num_stages=1, + # num_warps=4, + # ), + ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'] + + +def get_rdna_autotune_configs(): + return [ + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 16, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 16, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), + # Fails in AccelerateAMDMatmul (Triton) assert when using FP8: + # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 4, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 2, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + # # Fall-back config. + # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 1, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'] + + +def get_autotune_configs(): + if on_gfx1x(): + return get_rdna_autotune_configs() + else: + return get_cdna_autotune_configs() + + +autotune_configs, autotune_keys = get_autotune_configs() + +float8_info = torch.finfo(current_platform.fp8_dtype()) + + +@triton.autotune( + configs=autotune_configs, + key=autotune_keys, +) +@triton.jit +def attn_fwd( + Q, + K, + V, + bias, + sm_scale, + q_scale, + k_scale, + v_scale, + p_scale, + p_descale, + o_descale, + L, + Out, + stride_qz: tl.int64, + stride_qh: tl.int64, + stride_qm: tl.int64, + stride_qk: tl.int64, + stride_kz: tl.int64, + stride_kh: tl.int64, + stride_kn: tl.int64, + stride_kk: tl.int64, + stride_vz: tl.int64, + stride_vh: tl.int64, + stride_vk: tl.int64, + stride_vn: tl.int64, + stride_oz: tl.int64, + stride_oh: tl.int64, + stride_om: tl.int64, + stride_on: tl.int64, + stride_bz: tl.int64, + stride_bh: tl.int64, + stride_bm: tl.int64, + stride_bn: tl.int64, + cu_seqlens_q, + cu_seqlens_k, + dropout_p, + philox_seed, + philox_offset_base, + encoded_softmax, + HQ: tl.constexpr, + HK: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, + VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + USE_FP8: tl.constexpr, + USE_FP8_OUT: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BIAS_TYPE: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, +): + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if IS_CAUSAL: + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn + # matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is + # part of the blocks that are all 0. We exit early. + if n_blocks <= 0: + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + # We still need to write 0s to the result + # tl.store(O_block_ptr, + # acc.to(Out.type.element_ty), boundary_check=(0,1)) + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + # + offs_m + # We store inf to LSE, not -inf because in the bwd pass, + # we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 + # for these masked blocks. + # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + # tl.store(l_ptrs, l) + # TODO: Should dropout and return encoded softmax be handled here? + return + + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q + + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL + + # Compute pointers for all the tensors used in this kernel. + q_offset = (off_z * stride_qz + off_h_q * stride_qh + + cu_seqlens_q_start * stride_qm) + Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_offset = (off_z * stride_kz + off_h_k * stride_kh + + cu_seqlens_k_start * stride_kn) + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + v_offset = (off_z * stride_vz + off_h_k * stride_vh + + cu_seqlens_k_start * stride_vk) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + if BIAS_TYPE != 0: + bias_ptr = tl.make_block_ptr( + base=bias + off_h_q * stride_bh, + shape=(seqlen_q, seqlen_k), + strides=(stride_bm, stride_bn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + bias_ptr = None + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base \ + + (off_z * HQ + off_h_q) \ + * seqlen_q * seqlen_k + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. + # In this case, we return an invalid pointer so indicate the mask is not i + # valid. + # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.make_block_ptr( + base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, + shape=(seqlen_q, seqlen_k), + strides=(seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + encoded_softmax_block_ptr = 0 + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. + qk_scale = sm_scale * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q = load_fn(Q_block_ptr, True, padded_head, "zero") + if not USE_FP8: + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + acc_scale = 1.0 + else: + qk_scale *= q_scale * k_scale + acc_scale = p_scale * v_scale + + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional + # block. In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, + block_max, + 0, + 0, + 0, + bias_ptr, + # IS_CAUSAL, .... + False, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + False, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + USE_FP8, + qk_scale, + p_descale, + ) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if masked_blocks > 0: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 + K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, n_full_blocks)) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + True, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + USE_FP8, + qk_scale, + p_descale, + ) + # epilogue + + if USE_FP8: + acc *= acc_scale + acc = acc / l_i[:, None] + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + if USE_FP8_OUT: + acc *= o_descale + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: # noqa: SIM102 + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL, ), + causal_start_idx, + dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = (mask_m_offsets[:, None] + >= out_mask_boundary[None, :]) + z = tl.zeros((1, ), tl.float32) + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last + # few rows. This is only true for the last M block. For others, + # overflow_size will be -ve + # overflow_size = end_m_idx - seqlen_q + # if overflow_size > 0: + # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + # # This is a > check because mask being 0 blocks the store. + # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + # else: + # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # Need boundary check on this to make sure the padding from the + # Q and KV tensors in both dims are not part of what we store back. + # TODO: Do the boundary check optionally. + tl.store(O_block_ptr, acc, boundary_check=(0, 1)) + + +def check_args( + q, + k, + v, + o, + varlen=True, + max_seqlens=None, + cu_seqlens_q=None, + cu_seqlens_k=None, +): + assert q.dim() == k.dim() and q.dim() == v.dim() + if varlen: + assert q.dim() == 3 + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + assert cu_seqlens_q is not None + assert cu_seqlens_k is not None + assert len(cu_seqlens_q) == len(cu_seqlens_k) + else: + assert q.dim() == 4 + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + assert max_seqlens > 0 + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert q.dtype == k.dtype and q.dtype == v.dtype + assert head_size <= 256 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + o, + cu_seqlens_q, + cu_seqlens_k, + max_seqlens_q, + max_seqlens_k, + causal=False, + sm_scale=1.0, + bias=None, + fp8_scales=None, + fp8_out_scale=None, + ): + if fp8_scales is not None: + use_fp8 = True + (q_scale, k_scale, v_scale, p_scale) = fp8_scales + float8 = current_platform.fp8_dtype() + + def check_and_convert(t, scale): + if t.dtype != float8: + descale = 1.0 / scale + ts = (t * descale).clamp(min=float8_info.min, + max=float8_info.max) + return ts.to(float8) + else: + return t + + q = check_and_convert(q, q_scale) + k = check_and_convert(k, k_scale) + v = check_and_convert(v, v_scale) + else: + use_fp8 = False + q_scale = k_scale = v_scale = p_scale = 1.0 + + if o is None: + o = torch.empty_like(q, dtype=v.dtype) + + check_args( + q, + k, + v, + o, + varlen=True, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + if True: # varlen + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + batch = len(cu_seqlens_q) - 1 + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + else: + batch, seqlen_q, nheads_q, head_size = q.shape + _, seqlen_k, nheads_k, _ = k.shape + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + + # Get closest power of 2 over or equal to 32. + unpadded_head_dims = {32, 64, 128, 256} + if head_size not in unpadded_head_dims: + padded_d_model = None + for i in unpadded_head_dims: + if i > head_size: + padded_d_model = i + break + assert padded_d_model is not None + else: + padded_d_model = head_size + + grid = lambda META: ( + triton.cdiv(max_seqlens_q, META["BLOCK_M"]), + nheads_q, + batch, + ) + + encoded_softmax = None + + # Seed the RNG so we get reproducible results for testing. + philox_seed = 0x1BF52 + philox_offset = 0x1D4B42 + + if bias is not None: + bias_strides = ( + bias.stride(0), + bias.stride(1), + bias.stride(2), + bias.stride(3), + ) + else: + bias_strides = (0, 0, 0, 0) + + p_descale = 1.0 / p_scale + o_descale = 1.0 / fp8_out_scale.item( + ) if fp8_out_scale is not None else 1.0 + + arg_max_seqlens_q = 0 if on_gfx1x() else max_seqlens_q + arg_max_seqlens_k = 0 if on_gfx1x() else max_seqlens_k + + attn_fwd[grid]( + q, + k, + v, + bias, + sm_scale, + q_scale, + k_scale, + v_scale, + p_scale, + p_descale, + o_descale, + None, + o, + *q_strides, + *k_strides, + *v_strides, + *o_strides, + *bias_strides, + cu_seqlens_q, + cu_seqlens_k, + dropout_p=0.0, + philox_seed=philox_seed, + philox_offset_base=philox_offset, + encoded_softmax=encoded_softmax, + HQ=nheads_q, + HK=nheads_k, + ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=arg_max_seqlens_q, + MAX_SEQLENS_K=arg_max_seqlens_k, + IS_CAUSAL=causal, + VARLEN=True, + BLOCK_DMODEL=padded_d_model, + BIAS_TYPE=0 if bias is None else 1, + ENABLE_DROPOUT=False, + RETURN_ENCODED_SOFTMAX=False, + USE_FP8=use_fp8, + USE_FP8_OUT=fp8_out_scale is not None, + ) + + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = head_size + ctx.causal = causal + ctx.dropout_p = 0.0 + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.encoded_softmax = encoded_softmax + ctx.return_encoded_softmax = False + return o, encoded_softmax + + +triton_attention = _attention.apply diff --git a/vllm/attention/ops/triton_merge_attn_states.py b/vllm/attention/ops/triton_merge_attn_states.py new file mode 100644 index 0000000..56d78ed --- /dev/null +++ b/vllm/attention/ops/triton_merge_attn_states.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + + +# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 +# can be used to combine partial attention results (in the split-KV case) +def merge_attn_states( + output: torch.Tensor, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + suffix_output: torch.Tensor, + suffix_lse: torch.Tensor, + output_lse: Optional[torch.Tensor] = None, +) -> None: + num_tokens = output.shape[0] + num_query_heads = output.shape[1] + head_size = output.shape[2] + padded_head_size = triton.next_power_of_2(head_size) + + # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead. + merge_attn_states_kernel[(num_tokens, num_query_heads)]( + output, + output_lse, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + head_size, + padded_head_size, + output_lse is not None, + ) + + +@triton.jit +def merge_attn_states_kernel( + output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + output_lse, # [NUM_HEADS, NUM_TOKENS] + prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + prefix_lse, # [NUM_HEADS, NUM_TOKENS] + suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + suffix_lse, # [NUM_HEADS, NUM_TOKENS] + HEAD_SIZE: tl.constexpr, + PADDED_HEAD_SIZE: tl.constexpr, + OUTPUT_LSE: tl.constexpr, +): + token_idx = tl.program_id(0) + num_tokens = tl.num_programs(0) + head_idx = tl.program_id(1) + num_heads = tl.num_programs(1) + + p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx) + s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx) + + # FA2 and FA3 have different behavior for when the sum-exp is 0, this namely + # arises with 0 len seqlens. FA3 returns -inf here while FA2 returns inf. + # If we see an inf assume FA2 and convert inf to -inf for consistency + # and correctness. Inf generally doesn't make sense in this context outside + # of undefined-behavior/FA2-case, so I think this a safe assumption. + p_lse = float('-inf') if p_lse == float('inf') else p_lse + s_lse = float('-inf') if s_lse == float('inf') else s_lse + + max_lse = tl.maximum(p_lse, s_lse) + p_lse = p_lse - max_lse + s_lse = s_lse - max_lse + # Will reuse precomputed Exp values for scale factor computation. + p_se = tl.exp(p_lse) + s_se = tl.exp(s_lse) + out_se = (p_se + s_se) + + if OUTPUT_LSE: + out_lse = tl.log(out_se) + max_lse + tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse) + + head_arange = tl.arange(0, PADDED_HEAD_SIZE) + head_mask = head_arange < HEAD_SIZE + p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + mask=head_mask) + s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + mask=head_mask) + + # NOTE(woosuk): Be careful with the numerical stability. + # We should compute the scale first, and then multiply it with the output. + # Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly. + p_scale = p_se / out_se + s_scale = s_se / out_se + out = p_out * p_scale + s_out * s_scale + tl.store(output + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + head_arange, + out, + mask=head_mask) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py new file mode 100644 index 0000000..c65f095 --- /dev/null +++ b/vllm/attention/ops/triton_unified_attention.py @@ -0,0 +1,738 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Authors: +# - Burkhard Ringlein +# - Jan van Lunteren +# - Chih-Chieh Yang +# - Thomas Parnell + +import torch +import triton +import triton.language as tl + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def apply_softcap(S, x): + Sdiv = S / x + p1 = tl.exp(Sdiv) + p2 = tl.exp(-Sdiv) + return x * (p1 - p2) / (p1 + p2) + + +@triton.jit +def find_seq_idx(query_start_len_ptr, target_idx, num_seqs, + BLOCK_Q: tl.constexpr, use_q_block_mode: tl.constexpr): + left: tl.int32 = 0 + right = num_seqs + while left < right: + mid = (left + right) // 2 + val = tl.load(query_start_len_ptr + mid) + mid_val = val // BLOCK_Q + mid if use_q_block_mode else val + + if mid_val <= target_idx: + left = mid + 1 + else: + right = mid + + return left - 1 + + +@triton.jit +def kernel_unified_attention_2d( + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int +): + q_block_global_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + + seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, + BLOCK_Q, True) + + q_block_start_idx = tl.load(query_start_len_ptr + + seq_idx) // BLOCK_Q + seq_idx + + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index \ + - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = kv_head_idx * num_queries_per_kv + \ + offs_m % num_queries_per_kv + query_offset = (query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) + + dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) + query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) + + # Q : (BLOCK_M, HEAD_SIZE_PADDED) + Q = tl.load( + query_ptr + query_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # context length for this particular sequences + context_len = seq_len - cur_batch_query_len + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, + mask=query_mask_1, + other=0.0) + + num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) + + # iterate through tiles + for j in range(0, num_blocks): + + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + + offs_n = tl.arange(0, BLOCK_SIZE) + + v_offset = (physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + offs_n[:, None] * stride_v_cache_1) + + k_offset = (physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + offs_n[None, :] * stride_k_cache_1) + + # K : (HEAD_SIZE, BLOCK_SIZE) + K_load = tl.load(key_cache_ptr + k_offset, + mask=dim_mask[:, None], + other=0.0) + + if K_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + K = K_load + else: + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + # V : (BLOCK_SIZE, HEAD_SIZE) + V_load = tl.load(value_cache_ptr + v_offset, + mask=dim_mask[None, :], + other=0.0) + + if V_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + V = V_load + else: + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + seq_offset = j * BLOCK_SIZE + offs_n + + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 + + # S : (BLOCK_M, BLOCK_SIZE) + S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) + + S += scale * tl.dot(Q, K) + + if USE_SOFTCAP: + S = apply_softcap(S, softcap) + + S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, + S, float("-inf")) + + if SLIDING_WINDOW > 0: + S = tl.where((context_len + query_pos[:, None] - seq_offset) + < SLIDING_WINDOW, S, float("-inf")) + + if USE_ALIBI_SLOPES: + S += alibi_slope[:, None] * (seq_offset - context_len) + + # compute running maximum + # m_j : (BLOCK_M,) + m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a chance the max is -inf due to masking of + # the entire row. In this case we need to set m_j 0 to avoid NaN + m_j = tl.where(m_j > float("-inf"), m_j, 0.0) + + # P : (BLOCK_M, BLOCK_SIZE) + P = tl.exp(S - m_j[:, None]) + + # l_j : (BLOCK_M,) + l_j = tl.sum(P, axis=1) + + # alpha : (BLOCK_M, ) + alpha = tl.exp(M - m_j) + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc = acc * alpha[:, None] + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc += tl.dot(P.to(V.dtype), V) + + # epilogue + acc = acc / L[:, None] + + output_offset = (query_offset_0[:, None] * output_stride_0 + + query_offset_1[:, None] * output_stride_1 + + offs_d[None, :]) + + tl.store( + output_ptr + output_offset, + acc, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + ) + + +@triton.jit +def kernel_unified_attention_3d( + segm_output_ptr, + # [num_tokens, num_query_heads, num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int +): + q_block_global_idx = tl.program_id(0) + kv_head_idx = tl.program_id(1) + segm_idx = tl.program_id(2) + + seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, + BLOCK_Q, True) + + q_block_start_idx = tl.load(query_start_len_ptr + + seq_idx) // BLOCK_Q + seq_idx + + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index \ + - cur_batch_in_all_start_index + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # number of segments for this particular sequence + num_segments = NUM_SEGMENTS_PER_SEQ + blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) + + if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len: + return + + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_SIZE_PADDED) + + query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = kv_head_idx * num_queries_per_kv + \ + offs_m % num_queries_per_kv + + query_offset = (query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) + + dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) + query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) + + # Q : (BLOCK_M, HEAD_SIZE_PADDED) + Q = tl.load( + query_ptr + query_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + block_table_offset = seq_idx * block_table_stride + + M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32) + + # context length for this particular sequences + context_len = seq_len - cur_batch_query_len + + # alibi slope for this head + if USE_ALIBI_SLOPES: + alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, + mask=query_mask_1, + other=0.0) + + num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) + + # iterate through tiles within current segment + for j in range( + segm_idx * blocks_per_segment, + min((segm_idx + 1) * blocks_per_segment, num_blocks), + ): + physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) + + offs_n = tl.arange(0, BLOCK_SIZE) + + v_offset = (physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + offs_n[:, None] * stride_v_cache_1) + + k_offset = (physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + offs_n[None, :] * stride_k_cache_1) + + # K : (HEAD_SIZE, BLOCK_SIZE) + K_load = tl.load(key_cache_ptr + k_offset, + mask=dim_mask[:, None], + other=0.0) + + if K_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + K = K_load + else: + K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + else: + K = K_load + + # V : (BLOCK_SIZE, HEAD_SIZE) + V_load = tl.load(value_cache_ptr + v_offset, + mask=dim_mask[None, :], + other=0.0) + + if V_load.dtype.is_fp8(): + if Q.dtype.is_fp8(): + V = V_load + else: + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + else: + V = V_load + + seq_offset = j * BLOCK_SIZE + offs_n + + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 + + # S : (BLOCK_M, BLOCK_SIZE) + S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32) + + S += scale * tl.dot(Q, K) + + if USE_SOFTCAP: + S = apply_softcap(S, softcap) + + S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, + S, float("-inf")) + + if SLIDING_WINDOW > 0: + S = tl.where((context_len + query_pos[:, None] - seq_offset) + < SLIDING_WINDOW, S, float("-inf")) + + if USE_ALIBI_SLOPES: + S += alibi_slope[:, None] * (seq_offset - context_len) + + # compute running maximum + # m_j : (BLOCK_M,) + m_j = tl.maximum(M, tl.max(S, axis=1)) + # For sliding window there's a chance the max is -inf due to masking of + # the entire row. In this case we need to set m_j 0 to avoid NaN + m_j = tl.where(m_j > float("-inf"), m_j, 0.0) + + # P : (BLOCK_M, BLOCK_SIZE,) + P = tl.exp(S - m_j[:, None]) + + # l_j : (BLOCK_M,) + l_j = tl.sum(P, axis=1) + + # alpha : (BLOCK_M, ) + alpha = tl.exp(M - m_j) + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc = acc * alpha[:, None] + + # update constants + L = L * alpha + l_j + M = m_j + + # acc : (BLOCK_M, HEAD_SIZE_PADDED) + acc += tl.dot(P.to(V.dtype), V) + + segm_output_offset = ( + query_offset_0[:, None].to(tl.int64) * + (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + segm_idx * HEAD_SIZE_PADDED + tl.arange(0, HEAD_SIZE_PADDED)[None, :]) + tl.store( + segm_output_ptr + segm_output_offset, + acc, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + ) + segm_offset = (query_offset_0.to(tl.int64) * + (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_offset_1 * NUM_SEGMENTS_PER_SEQ + segm_idx) + tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1) + tl.store(segm_expsum_ptr + segm_offset, + L, + mask=query_mask_0 & query_mask_1) + + +@triton.jit +def reduce_segments( + output_ptr, # [num_tokens, num_query_heads, head_size] + segm_output_ptr, + #[num_tokens, num_query_heads, max_num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] + seq_lens_ptr, # [num_seqs] + num_seqs, # int + num_query_heads: tl.constexpr, # int + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + block_table_stride: tl.int64, # int + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int, must be power of 2 + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int +): + query_token_idx = tl.program_id(0) + query_head_idx = tl.program_id(1) + + seq_idx = find_seq_idx(query_start_len_ptr, query_token_idx, num_seqs, + BLOCK_Q, False) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # number of segments for this particular sequence + num_segments = NUM_SEGMENTS_PER_SEQ + blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE) + + # create masks for subsequent loads + act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE) + segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full( + [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32) + dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, + 0).to(tl.int1) + + # load segment maxima + segm_offset = (query_token_idx.to(tl.int64) * + (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_head_idx * NUM_SEGMENTS_PER_SEQ + + tl.arange(0, NUM_SEGMENTS_PER_SEQ)) + segm_max = tl.load(segm_max_ptr + segm_offset, + mask=segm_mask, + other=float("-inf")) + overall_max = tl.max(segm_max) + + # load and rescale segment exp sums + segm_expsum = tl.load(segm_expsum_ptr + segm_offset, + mask=segm_mask, + other=0.0) + segm_expsum = segm_expsum * tl.exp(segm_max - overall_max) + overall_expsum = tl.sum(segm_expsum) + + # load, rescale, and add segment attention outputs + segm_output_offset = ( + query_token_idx.to(tl.int64) * + (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED + + tl.arange(0, HEAD_SIZE_PADDED)[None, :]) + segm_output = tl.load( + segm_output_ptr + segm_output_offset, + mask=segm_mask[:, None] & dim_mask[None, :], + other=0.0, + ) + segm_output *= tl.exp(segm_max - overall_max)[:, None] + acc_sum = tl.sum(segm_output, axis=0) + # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0 + acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum) + + # write result + output_offset = (query_token_idx * output_stride_0 + + query_head_idx * output_stride_1 + + tl.arange(0, HEAD_SIZE_PADDED)) + tl.store(output_ptr + output_offset, acc, mask=dim_mask) + + +def unified_attention( + q, + k, + v, + out, + cu_seqlens_q, + max_seqlen_q, + seqused_k, + max_seqlen_k, + softmax_scale, + causal, + window_size, + block_table, + softcap, + q_descale, + k_descale, + v_descale, + alibi_slopes=None, +): + assert causal, "Only causal attention is supported" + assert q_descale is None, "Q scales not supported" + + block_size = v.shape[1] + assert q.element_size() >= 2 or block_size >= 32, \ + "Block size must be at least 32 for fp8" + + use_alibi_slopes = alibi_slopes is not None + + block_size = v.shape[1] + num_seqs = len(seqused_k) + num_query_heads = q.shape[1] + num_kv_heads = k.shape[2] + num_queries_per_kv = num_query_heads // num_kv_heads + head_size = q.shape[2] + + BLOCK_M = 16 + BLOCK_Q = BLOCK_M // num_queries_per_kv + + # Ideally we would launch with kernel with: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks. + # However, it is slow to realize the query_lens on cpu. + # Instead we use upper-bound: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] + # <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1] + # = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs + # <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs + # = floor(q.shape[0] / BLOCK_Q) + num_seqs + total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs + + # if batch contains a prefill + if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: + kernel_unified_attention_2d[( + total_num_q_blocks, + num_kv_heads, + )]( + output_ptr=out, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_SOFTCAP=(softcap > 0), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + ) + else: + # for initial version, NUM_SEGMENTS = 16 is chosen as a default + # value that showed good performance in tests + NUM_SEGMENTS = 16 + + segm_output = torch.empty( + q.shape[0], + num_query_heads, + NUM_SEGMENTS, + triton.next_power_of_2(head_size), + dtype=torch.float32, + device=q.device, + ) + segm_max = torch.empty( + q.shape[0], + num_query_heads, + NUM_SEGMENTS, + dtype=torch.float32, + device=q.device, + ) + segm_expsum = torch.empty( + q.shape[0], + num_query_heads, + NUM_SEGMENTS, + dtype=torch.float32, + device=q.device, + ) + + kernel_unified_attention_3d[( + total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_SOFTCAP=(softcap > 0), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + ) + + reduce_segments[(q.shape[0], num_query_heads)]( + output_ptr=out, + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + seq_lens_ptr=seqused_k, + num_seqs=num_seqs, + num_query_heads=num_query_heads, + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + block_table_stride=block_table.stride(0), + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + ) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py new file mode 100644 index 0000000..df14aea --- /dev/null +++ b/vllm/attention/selector.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from contextlib import contextmanager +from functools import cache +from typing import Generator, Optional, Union + +import torch + +import vllm.envs as envs +from vllm.attention.backends.abstract import AttentionBackend +from vllm.logger import init_logger +from vllm.platforms import _Backend, current_platform +from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname + +logger = init_logger(__name__) + + +def backend_name_to_enum(backend_name: str) -> Optional[_Backend]: + """ + Convert a string backend name to a _Backend enum value. + + Returns: + * _Backend: enum value if backend_name is a valid in-tree type + * None: otherwise it's an invalid in-tree type or an out-of-tree platform is + loaded. + """ + assert backend_name is not None + return _Backend[backend_name] if backend_name in _Backend.__members__ else \ + None + + +def get_env_variable_attn_backend() -> Optional[_Backend]: + ''' + Get the backend override specified by the vLLM attention + backend environment variable, if one is specified. + + Returns: + + * _Backend enum value if an override is specified + * None otherwise + ''' + backend_name = os.environ.get(STR_BACKEND_ENV_VAR) + return (None + if backend_name is None else backend_name_to_enum(backend_name)) + + +# Global state allows a particular choice of backend +# to be forced, overriding the logic which auto-selects +# a backend based on system & workload configuration +# (default behavior if this variable is None) +# +# THIS SELECTION TAKES PRECEDENCE OVER THE +# VLLM_ATTENTION_BACKEND ENVIRONMENT VARIABLE +forced_attn_backend: Optional[_Backend] = None + + +def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None: + ''' + Force all attention operations to use a specified backend. + + Passing `None` for the argument re-enables automatic + backend selection., + + Arguments: + + * attn_backend: backend selection (None to revert to auto) + ''' + global forced_attn_backend + forced_attn_backend = attn_backend + + +def get_global_forced_attn_backend() -> Optional[_Backend]: + ''' + Get the currently-forced choice of attention backend, + or None if auto-selection is currently enabled. + ''' + return forced_attn_backend + + +def supports_head_size( + attn_backend: Union[str, type[AttentionBackend]], + head_size: int, +) -> bool: + if isinstance(attn_backend, str): + try: + attn_backend = resolve_obj_by_qualname(attn_backend) + except ImportError: + return False + + assert isinstance(attn_backend, type) + + # TODO: Update the interface once V0 is removed + if get_supported_head_sizes := getattr(attn_backend, + "get_supported_head_sizes", None): + return head_size in get_supported_head_sizes() + if validate_head_size := getattr(attn_backend, "validate_head_size", None): + try: + validate_head_size(head_size) + return True + except Exception: + return False + + raise NotImplementedError(f"{attn_backend.__name__} does not support " + "head size validation") + + +def get_attn_backend( + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + is_attention_free: bool, + is_blocksparse: bool = False, + use_mla: bool = False, +) -> type[AttentionBackend]: + """Selects which attention backend to use and lazily imports it.""" + # Accessing envs.* behind an @lru_cache decorator can cause the wrong + # value to be returned from the cache if the value changes between calls. + # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the + # private function. + return _cached_get_attn_backend( + head_size=head_size, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype, + block_size=block_size, + is_attention_free=is_attention_free, + is_blocksparse=is_blocksparse, + use_v1=envs.VLLM_USE_V1, + use_mla=use_mla, + ) + + +@cache +def _cached_get_attn_backend( + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + is_attention_free: bool, + is_blocksparse: bool = False, + use_v1: bool = False, + use_mla: bool = False, +) -> type[AttentionBackend]: + if is_blocksparse: + logger.info("Using BlocksparseFlashAttention backend.") + from vllm.attention.backends.blocksparse_attn import ( + BlocksparseFlashAttentionBackend) + return BlocksparseFlashAttentionBackend + + # If there are no attention layers (e.g. we are running Mamba), + # use the placeholder NO_ATTENTION + if is_attention_free: + from vllm.attention.backends.placeholder_attn import ( + PlaceholderAttentionBackend) + return PlaceholderAttentionBackend + + # Check whether a particular choice of backend was + # previously forced. + # + # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND + # ENVIRONMENT VARIABLE. + selected_backend = None + backend_by_global_setting: Optional[_Backend] = ( + get_global_forced_attn_backend()) + if backend_by_global_setting is not None: + selected_backend = backend_by_global_setting + else: + # Check the environment variable and override if specified + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + selected_backend = backend_name_to_enum(backend_by_env_var) + + # get device-specific attn_backend + attention_cls = current_platform.get_attn_backend_cls( + selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, + use_mla) + if not attention_cls: + raise ValueError( + f"Invalid attention backend for {current_platform.device_name}") + return resolve_obj_by_qualname(attention_cls) + + +@contextmanager +def global_force_attn_backend_context_manager( + attn_backend: _Backend) -> Generator[None, None, None]: + ''' + Globally force a vLLM attention backend override within a + context manager, reverting the global attention backend + override to its prior state upon exiting the context + manager. + + Arguments: + + * attn_backend: attention backend to force + + Returns: + + * Generator + ''' + + # Save the current state of the global backend override (if any) + original_value = get_global_forced_attn_backend() + + # Globally force the new backend override + global_force_attn_backend(attn_backend) + + # Yield control back to the enclosed code block + try: + yield + finally: + # Revert the original global backend override, if any + global_force_attn_backend(original_value) diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py new file mode 100644 index 0000000..95156f1 --- /dev/null +++ b/vllm/attention/utils/fa_utils.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +from vllm import envs +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +if current_platform.is_cuda(): + from vllm import _custom_ops as ops + reshape_and_cache_flash = ops.reshape_and_cache_flash + from vllm.vllm_flash_attn import (flash_attn_varlen_func, + get_scheduler_metadata) +elif current_platform.is_rocm(): + from vllm import _custom_ops as ops + reshape_and_cache_cuda = ops.reshape_and_cache_cuda + from flash_attn import vllm_flash_attn_varlen_func +elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops as ops + reshape_and_cache_flash = ops.reshape_and_cache_flash + flash_attn_varlen_func = ops.flash_attn_varlen_func + get_scheduler_metadata = ops.get_scheduler_metadata + + +def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: + # import here to avoid circular dependencies + from vllm.platforms import current_platform + if current_platform.is_xpu(): + return 2 + try: + from vllm.vllm_flash_attn.flash_attn_interface import ( + fa_version_unsupported_reason, is_fa_version_supported) + device_capability = current_platform.get_device_capability() + + assert device_capability is not None + + # 1. default version depending on platform + fa_version = 3 if (device_capability.major == 9 + and is_fa_version_supported(3)) else 2 + + # 2. override if passed by environment + if envs.VLLM_FLASH_ATTN_VERSION is not None: + assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3] + fa_version = envs.VLLM_FLASH_ATTN_VERSION + + # 3. fallback for unsupported combinations + if device_capability.major == 10 and fa_version == 3: + logger.warning_once( + "Cannot use FA version 3 on Blackwell platform " + "defaulting to FA version 2.") + fa_version = 2 + + if requires_alibi and fa_version == 3: + logger.warning_once("Cannot use FA version 3 with ALiBi, " + "defaulting to FA version 2.") + fa_version = 2 + + if not is_fa_version_supported(fa_version): + logger.error("Cannot use FA version %d is not supported due to %s", + fa_version, fa_version_unsupported_reason(fa_version)) + + assert is_fa_version_supported(fa_version) + return fa_version + except (ImportError, AssertionError): + return None + + +def flash_attn_supports_fp8() -> bool: + return get_flash_attn_version() == 3 and \ + current_platform.get_device_capability().major == 9 + + +def is_flash_attn_varlen_func_available() -> bool: + return current_platform.is_cuda() or current_platform.is_rocm() or current_platform.is_xpu() diff --git a/vllm/beam_search.py b/vllm/beam_search.py new file mode 100644 index 0000000..f3bc421 --- /dev/null +++ b/vllm/beam_search.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, Union + +from vllm.lora.request import LoRARequest +from vllm.sequence import Logprob + +if TYPE_CHECKING: + from vllm.multimodal import MultiModalDataDict + + +@dataclass +class BeamSearchSequence: + """A sequence for beam search. + It keeps track of the tokens and the log probability of the sequence. + The text field is optional and will only be filled when the sequence is + about to be returned to the user. + """ + # The tokens includes the prompt. + tokens: list[int] + logprobs: list[dict[int, Logprob]] + lora_request: Optional[LoRARequest] = None + cum_logprob: float = 0.0 + text: Optional[str] = None + finish_reason: Optional[str] = None + stop_reason: Union[int, str, None] = None + multi_modal_data: Optional["MultiModalDataDict"] = None + mm_processor_kwargs: Optional[dict[str, Any]] = None + + +@dataclass +class BeamSearchOutput: + """The output of beam search. + It contains the list of the best beam search sequences. + The length of the list is equal to the beam width. + """ + sequences: list[BeamSearchSequence] + + +class BeamSearchInstance: + + def __init__( + self, + prompt_tokens: list[int], + lora_request: Optional[LoRARequest] = None, + logprobs: Optional[list[dict[int, Logprob]]] = None, + **kwargs, + ): + self.beams: list[BeamSearchSequence] = [ + BeamSearchSequence( + tokens=prompt_tokens, + logprobs=[] if logprobs is None else list(logprobs), + lora_request=lora_request, + **kwargs, + ) + ] + self.completed: list[BeamSearchSequence] = [] + + +def get_beam_search_score( + tokens: list[int], + cumulative_logprob: float, + eos_token_id: int, + length_penalty: float = 1.0, +) -> float: + """Calculate the beam search score with length penalty. + + Adapted from + + https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938 + """ + seq_len = len(tokens) + if tokens[-1] == eos_token_id: + seq_len -= 1 + + return cumulative_logprob / (seq_len**length_penalty) + + +def create_sort_beams_key_function(eos_token_id: int, length_penalty: float): + + def sort_beams_key(x: BeamSearchSequence) -> float: + return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id, + length_penalty) + + return sort_beams_key diff --git a/vllm/benchmarks/__init__.py b/vllm/benchmarks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py new file mode 100644 index 0000000..b3688d2 --- /dev/null +++ b/vllm/benchmarks/datasets.py @@ -0,0 +1,1441 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This module defines a framework for sampling benchmark requests from various +datasets. Each dataset subclass of BenchmarkDataset must implement sample +generation. Supported dataset types include: + - ShareGPT + - Random (synthetic) + - Sonnet + - BurstGPT + - HuggingFace + - VisionArena +""" +import base64 +import io +import json +import logging +import random +from abc import ABC, abstractmethod +from collections.abc import Mapping +from dataclasses import dataclass +from functools import cache +from io import BytesIO +from typing import Any, Callable, Optional, Union + +import numpy as np +from PIL import Image +from transformers import PreTrainedTokenizerBase + +from vllm.lora.request import LoRARequest +from vllm.lora.utils import get_adapter_absolute_path +from vllm.multimodal import MultiModalDataDict +from vllm.multimodal.image import convert_image_mode +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer +from vllm.utils import PlaceholderModule + +try: + from datasets import load_dataset +except ImportError: + datasets = PlaceholderModule("datasets") + load_dataset = datasets.placeholder_attr("load_dataset") + +try: + import pandas as pd +except ImportError: + pd = PlaceholderModule("pandas") + +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") + +try: + from vllm.utils import FlexibleArgumentParser +except ImportError: + from argparse import ArgumentParser as FlexibleArgumentParser + +logger = logging.getLogger(__name__) + +# ----------------------------------------------------------------------------- +# Data Classes +# ----------------------------------------------------------------------------- + + +@dataclass +class SampleRequest: + """ + Represents a single inference request for benchmarking. + """ + + prompt: Union[str, Any] + prompt_len: int + expected_output_len: int + multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None + lora_request: Optional[LoRARequest] = None + + +# ----------------------------------------------------------------------------- +# Benchmark Dataset Base Class +# ----------------------------------------------------------------------------- + + +class BenchmarkDataset(ABC): + DEFAULT_SEED = 0 + IS_MULTIMODAL = False + + def __init__( + self, + dataset_path: Optional[str] = None, + random_seed: int = DEFAULT_SEED, + ) -> None: + """ + Initialize the BenchmarkDataset with an optional dataset path and random + seed. + + Args: + dataset_path (Optional[str]): Path to the dataset. If None, it + indicates that a default or random dataset might be used. + random_seed (int): Seed value for reproducible shuffling or + sampling. Defaults to DEFAULT_SEED. + """ + self.dataset_path = dataset_path + # Set the random seed, ensuring that a None value is replaced with the + # default seed. + self.random_seed = (random_seed + if random_seed is not None else self.DEFAULT_SEED) + self.data = None + + def apply_multimodal_chat_transformation( + self, + prompt: str, + mm_content: Optional[MultiModalDataDict] = None) -> list[dict]: + """ + Transform a prompt and optional multimodal content into a chat format. + This method is used for chat models that expect a specific conversation + format. + """ + content = [{"text": prompt, "type": "text"}] + if mm_content is not None: + content.append(mm_content) + return [{"role": "user", "content": content}] + + def load_data(self) -> None: + """ + Load data from the dataset path into self.data. + + This method must be overridden by subclasses since the method to load + data will vary depending on the dataset format and source. + + Raises: + NotImplementedError: If a subclass does not implement this method. + """ + # TODO (jenniferzhao): add support for downloading data + raise NotImplementedError( + "load_data must be implemented in subclasses.") + + def get_random_lora_request( + self, + tokenizer: PreTrainedTokenizerBase, + max_loras: Optional[int] = None, + lora_path: Optional[str] = None, + ) -> tuple[Optional[LoRARequest], AnyTokenizer]: + """ + Optionally select a random LoRA request and return its associated + tokenizer. + + This method is used when LoRA parameters are provided. It randomly + selects a LoRA based on max_loras and retrieves a cached tokenizer for + that LoRA if available. Otherwise, it returns the base tokenizer. + + Args: + tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no + LoRA is selected. + max_loras (Optional[int]): The maximum number of LoRAs available. + If `None`, LoRA is not used. + lora_path (Optional[str]): Path to the LoRA parameters on disk. + If `None`, LoRA is not used. + + Returns: + A tuple with the following elements: + - A new [LoRARequest][] (or `None` if not applicable). + - The tokenizer associated with the LoRA request + (or the base tokenizer). + """ + if max_loras is None or lora_path is None: + return None, tokenizer + + # Generate a random LoRA ID in the range [1, max_loras]. + lora_id = random.randint(1, max_loras) + lora_request = LoRARequest( + lora_name=str(lora_id), + lora_int_id=lora_id, + lora_path=lora_path_on_disk(lora_path), + ) + if lora_id not in lora_tokenizer_cache: + lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request) + # Return lora_request and the cached tokenizer if available; otherwise, + # return the base tokenizer + return lora_request, lora_tokenizer_cache[lora_id] or tokenizer + + @abstractmethod + def sample(self, tokenizer: PreTrainedTokenizerBase, + num_requests: int) -> list[SampleRequest]: + """ + Abstract method to generate sample requests from the dataset. + + Subclasses must override this method to implement dataset-specific logic + for generating a list of SampleRequest objects. + + Args: + tokenizer (PreTrainedTokenizerBase): The tokenizer to be used + for processing the dataset's text. + num_requests (int): The number of sample requests to generate. + + Returns: + list[SampleRequest]: A list of sample requests generated from the + dataset. + """ + raise NotImplementedError("sample must be implemented in subclasses.") + + def maybe_oversample_requests(self, requests: list[SampleRequest], + num_requests: int) -> None: + """ + Oversamples the list of requests if its size is less than the desired + number. + + Args: + requests (List[SampleRequest]): The current list of sampled + requests. + num_requests (int): The target number of requests. + """ + if len(requests) < num_requests: + random.seed(self.random_seed) + additional = random.choices(requests, + k=num_requests - len(requests)) + requests.extend(additional) + logger.info("Oversampled requests to reach %d total samples.", + num_requests) + + +# ----------------------------------------------------------------------------- +# Utility Functions and Global Caches +# ----------------------------------------------------------------------------- + + +def is_valid_sequence( + prompt_len: int, + output_len: int, + min_len: int = 4, + max_prompt_len: int = 1024, + max_total_len: int = 2048, + skip_min_output_len_check: bool = False, +) -> bool: + """ + Validate a sequence based on prompt and output lengths. + + Default pruning criteria are copied from the original `sample_hf_requests` + and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as + from `sample_requests` in benchmark_throughput.py. + """ + # Check for invalid conditions + prompt_too_short = prompt_len < min_len + output_too_short = (not skip_min_output_len_check) and (output_len + < min_len) + prompt_too_long = prompt_len > max_prompt_len + combined_too_long = (prompt_len + output_len) > max_total_len + + # Return True if none of the invalid conditions are met + return not (prompt_too_short or output_too_short or prompt_too_long + or combined_too_long) + + +@cache +def lora_path_on_disk(lora_path: str) -> str: + return get_adapter_absolute_path(lora_path) + + +# Global cache for LoRA tokenizers. +lora_tokenizer_cache: dict[int, AnyTokenizer] = {} + + +def process_image(image: Any) -> Mapping[str, Any]: + """ + Process a single image input and return a multimedia content dictionary. + + Supports three input types: + + 1. Dictionary with raw image bytes: - Expects a dict with a 'bytes' key + containing raw image data. - Loads the bytes as a PIL.Image.Image. + + 2. PIL.Image.Image input: - Converts the image to RGB. - Saves the image as + a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns + a dictionary with the image as a base64 data URL. + + 3. String input: - Treats the string as a URL or local file path. - + Prepends "file://" if the string doesn't start with "http://" or + "file://". - Returns a dictionary with the image URL. + + Raises: + ValueError: If the input is not a supported type. + """ + if isinstance(image, dict) and 'bytes' in image: + image = Image.open(BytesIO(image['bytes'])) + if isinstance(image, Image.Image): + image = convert_image_mode(image, "RGB") + with io.BytesIO() as image_data: + image.save(image_data, format="JPEG") + image_base64 = base64.b64encode( + image_data.getvalue()).decode("utf-8") + return { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + } + + if isinstance(image, str): + image_url = (image if image.startswith( + ("http://", "file://")) else f"file://{image}") + return {"type": "image_url", "image_url": {"url": image_url}} + + raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image" + " or str or dictionary with raw image bytes.") + + +# ----------------------------------------------------------------------------- +# Random Dataset Implementation (Synthetic Data) +# ----------------------------------------------------------------------------- + + +class RandomDataset(BenchmarkDataset): + # Default values copied from benchmark_serving.py for the random dataset. + DEFAULT_PREFIX_LEN = 0 + DEFAULT_RANGE_RATIO = 0.0 + DEFAULT_INPUT_LEN = 1024 + DEFAULT_OUTPUT_LEN = 128 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + random.seed(self.random_seed) + np.random.seed(self.random_seed) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + range_ratio: float = DEFAULT_RANGE_RATIO, + input_len: int = DEFAULT_INPUT_LEN, + output_len: int = DEFAULT_OUTPUT_LEN, + **kwargs, + ) -> list[SampleRequest]: + # Enforce range_ratio < 1 + assert range_ratio < 1.0, ( + "random_range_ratio must be < 1.0 to ensure a valid sampling range" + ) + + vocab_size = tokenizer.vocab_size + num_special_tokens = tokenizer.num_special_tokens_to_add() + real_input_len = input_len - num_special_tokens + + prefix_token_ids = (np.random.randint( + 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) + + # New sampling logic: [X * (1 - b), X * (1 + b)] + input_low = int(real_input_len * (1 - range_ratio)) + input_high = int(real_input_len * (1 + range_ratio)) + output_low = int(output_len * (1 - range_ratio)) + output_high = int(output_len * (1 + range_ratio)) + + # Add logging for debugging + logger.info( + "Sampling input_len from [%s, %s] and output_len from [%s, %s]", + input_low, input_high, output_low, output_high) + + input_lens = np.random.randint(input_low, + input_high + 1, + size=num_requests) + output_lens = np.random.randint(output_low, + output_high + 1, + size=num_requests) + offsets = np.random.randint(0, vocab_size, size=num_requests) + + requests = [] + for i in range(num_requests): + inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) % + vocab_size).tolist() + token_sequence = prefix_token_ids + inner_seq + prompt = tokenizer.decode(token_sequence) + # After decoding the prompt we have to encode and decode it again. + # This is done because in some cases N consecutive tokens + # give a string tokenized into != N number of tokens. + # For example for GPT2Tokenizer: + # [6880, 6881] -> ['Ġcalls', 'here'] -> + # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] + # To avoid uncontrolled change of the prompt length, + # the encoded sequence is truncated before being decode again. + total_input_len = prefix_len + int(input_lens[i]) + re_encoded_sequence = tokenizer.encode( + prompt, add_special_tokens=False)[:total_input_len] + prompt = tokenizer.decode(re_encoded_sequence) + total_input_len = len(re_encoded_sequence) + requests.append( + SampleRequest( + prompt=prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + )) + return requests + + +# ----------------------------------------------------------------------------- +# ShareGPT Dataset Implementation +# ----------------------------------------------------------------------------- + + +class ShareGPTDataset(BenchmarkDataset): + """ + Implements the ShareGPT dataset. Loads data from a JSON file and generates + sample requests based on conversation turns. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + with open(self.dataset_path, encoding="utf-8") as f: + self.data = json.load(f) + # Filter entries with at least two conversation turns. + self.data = [ + entry for entry in self.data + if "conversations" in entry and len(entry["conversations"]) >= 2 + ] + random.seed(self.random_seed) + random.shuffle(self.data) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + samples: list = [] + for entry in self.data: + if len(samples) >= num_requests: + break + prompt, completion = ( + entry["conversations"][0]["value"], + entry["conversations"][1]["value"], + ) + + lora_request, tokenizer = self.get_random_lora_request( + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + new_output_len = (len(completion_ids) + if output_len is None else output_len) + if not is_valid_sequence(prompt_len, + new_output_len, + skip_min_output_len_check=output_len + is not None): + continue + if enable_multimodal_chat: + prompt = self.apply_multimodal_chat_transformation( + prompt, None) + samples.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=new_output_len, + lora_request=lora_request, + )) + self.maybe_oversample_requests(samples, num_requests) + return samples + + +def add_dataset_parser(parser: FlexibleArgumentParser): + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="random", + choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"], + help="Name of the dataset to benchmark on.", + ) + parser.add_argument( + "--dataset-path", + type=str, + default=None, + help="Path to the sharegpt/sonnet dataset. " + "Or the huggingface dataset ID if using HF dataset.", + ) + + # group for dataset specific arguments + custom_group = parser.add_argument_group("custom dataset options") + custom_group.add_argument( + "--custom-output-len", + type=int, + default=256, + help= + "Number of output tokens per request, used only for custom dataset.", + ) + custom_group.add_argument( + "--custom-skip-chat-template", + action="store_true", + help= + "Skip applying chat template to prompt, used only for custom dataset.", + ) + + sonnet_group = parser.add_argument_group("sonnet dataset options") + sonnet_group.add_argument( + "--sonnet-input-len", + type=int, + default=550, + help= + "Number of input tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-output-len", + type=int, + default=150, + help= + "Number of output tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-prefix-len", + type=int, + default=200, + help= + "Number of prefix tokens per request, used only for sonnet dataset.", + ) + + sharegpt_group = parser.add_argument_group("sharegpt dataset options") + sharegpt_group.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output length " + "from the ShareGPT dataset.", + ) + + random_group = parser.add_argument_group("random dataset options") + random_group.add_argument( + "--random-input-len", + type=int, + default=1024, + help= + "Number of input tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-output-len", + type=int, + default=128, + help= + "Number of output tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range ratio for sampling input/output length, " + "used only for random sampling. Must be in the range [0, 1) to define " + "a symmetric sampling range" + "[length * (1 - range_ratio), length * (1 + range_ratio)].", + ) + random_group.add_argument( + "--random-prefix-len", + type=int, + default=0, + help=("Number of fixed prefix tokens before the random context " + "in a request. " + "The total input length is the sum of `random-prefix-len` and " + "a random " + "context length sampled from [input_len * (1 - range_ratio), " + "input_len * (1 + range_ratio)]."), + ) + + hf_group = parser.add_argument_group("hf dataset options") + hf_group.add_argument("--hf-subset", + type=str, + default=None, + help="Subset of the HF dataset.") + hf_group.add_argument("--hf-split", + type=str, + default=None, + help="Split of the HF dataset.") + hf_group.add_argument( + "--hf-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output lengths " + "from the sampled HF dataset.", + ) + + +def get_samples(args, tokenizer) -> list[SampleRequest]: + if args.dataset_name == "custom": + dataset = CustomDataset(dataset_path=args.dataset_path) + input_requests = dataset.sample( + num_requests=args.num_prompts, + tokenizer=tokenizer, + output_len=args.custom_output_len, + skip_chat_template=args.custom_skip_chat_template, + ) + + elif args.dataset_name == "sonnet": + dataset = SonnetDataset(dataset_path=args.dataset_path) + # For the "sonnet" dataset, formatting depends on the backend. + if args.endpoint_type == "openai-chat": + input_requests = dataset.sample( + num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=False, + ) + else: + assert tokenizer.chat_template or tokenizer.default_chat_template, ( + "Tokenizer/model must have chat template for sonnet dataset.") + input_requests = dataset.sample( + num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=True, + ) + + elif args.dataset_name == "hf": + # all following datasets are implemented from the + # HuggingFaceDataset base class + if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: + dataset_class = VisionArenaDataset + args.hf_split = "train" + args.hf_subset = None + elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: + dataset_class = InstructCoderDataset + args.hf_split = "train" + elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS: + dataset_class = MTBenchDataset + args.hf_split = "train" + elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: + dataset_class = ConversationDataset + elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: + dataset_class = AIMODataset + args.hf_split = "train" + elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501 + dataset_class = NextEditPredictionDataset + args.hf_split = "train" + elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: + dataset_class = ASRDataset + args.hf_split = "train" + else: + supported_datasets = set([ + dataset_name for cls in HuggingFaceDataset.__subclasses__() + for dataset_name in cls.SUPPORTED_DATASET_PATHS + ]) + raise ValueError( + f"Unsupported dataset path: {args.dataset_path}. " + "Huggingface dataset only supports dataset_path" + f" from one of following: {supported_datasets}. " + "Please consider contributing if you would " + "like to add support for additional dataset formats.") + + if dataset_class.IS_MULTIMODAL and args.endpoint_type not in [ + "openai-chat", + "openai-audio", + ]: + # multi-modal benchmark is only available on OpenAI Chat backend. + raise ValueError( + "Multi-modal content is only supported on 'openai-chat' and " + "'openai-audio' backend.") + input_requests = dataset_class( + dataset_path=args.dataset_path, + dataset_subset=args.hf_subset, + dataset_split=args.hf_split, + random_seed=args.seed, + ).sample( + num_requests=args.num_prompts, + tokenizer=tokenizer, + output_len=args.hf_output_len, + ) + + else: + # For datasets that follow a similar structure, use a mapping. + dataset_mapping = { + "sharegpt": + lambda: ShareGPTDataset(random_seed=args.seed, + dataset_path=args.dataset_path).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len, + ), + "burstgpt": + lambda: BurstGPTDataset(random_seed=args.seed, + dataset_path=args.dataset_path). + sample(tokenizer=tokenizer, num_requests=args.num_prompts), + "random": + lambda: RandomDataset(random_seed=args.seed, + dataset_path=args.dataset_path).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.random_prefix_len, + input_len=args.random_input_len, + output_len=args.random_output_len, + range_ratio=args.random_range_ratio, + ), + } + + try: + input_requests = dataset_mapping[args.dataset_name]() + except KeyError as err: + raise ValueError(f"Unknown dataset: {args.dataset_name}") from err + + return input_requests + + +# ----------------------------------------------------------------------------- +# Custom Dataset Implementation +# ----------------------------------------------------------------------------- + + +class CustomDataset(BenchmarkDataset): + """ + Implements the Custom dataset. Loads data from a JSONL file and generates + sample requests based on conversation turns. E.g., + ``` + {"prompt": "What is the capital of India?"} + {"prompt": "What is the capital of Iran?"} + {"prompt": "What is the capital of China?"} + ``` + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + # self.data will be a list of dictionaries + # e.g., [{"prompt": "What is the capital of India?"}, ...] + # This will be the standardized format which load_data() + # has to convert into depending on the filetype of dataset_path. + # sample() will assume this standardized format of self.data + self.data = [] + + # Load the JSONL file + if self.dataset_path.endswith(".jsonl"): + jsonl_data = pd.read_json(path_or_buf=self.dataset_path, + lines=True) + + # check if the JSONL file has a 'prompt' column + if "prompt" not in jsonl_data.columns: + raise ValueError("JSONL file must contain a 'prompt' column.") + + # Convert each row to a dictionary and append to self.data + # This will convert the DataFrame to a list of dictionaries + # where each dictionary corresponds to a row in the DataFrame. + # This is the standardized format we want for self.data + for _, row in jsonl_data.iterrows(): + self.data.append(row.to_dict()) + else: + raise NotImplementedError( + "Only JSONL format is supported for CustomDataset.") + + random.seed(self.random_seed) + random.shuffle(self.data) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + skip_chat_template: bool = False, + **kwargs, + ) -> list: + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = item["prompt"] + + # apply template + if not skip_chat_template: + prompt = tokenizer.apply_chat_template( + [{ + "role": "user", + "content": prompt + }], + add_generation_prompt=True, + tokenize=False, + ) + + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Sonnet Dataset Implementation +# ----------------------------------------------------------------------------- + + +class SonnetDataset(BenchmarkDataset): + """ + Simplified implementation of the Sonnet dataset. Loads poem lines from a + text file and generates sample requests. Default values here copied from + `benchmark_serving.py` for the sonnet dataset. + """ + + DEFAULT_PREFIX_LEN = 200 + DEFAULT_INPUT_LEN = 550 + DEFAULT_OUTPUT_LEN = 150 + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if not self.dataset_path: + raise ValueError("dataset_path must be provided.") + with open(self.dataset_path, encoding="utf-8") as f: + self.data = f.readlines() + + def sample( + self, + tokenizer, + num_requests: int, + prefix_len: int = DEFAULT_PREFIX_LEN, + input_len: int = DEFAULT_INPUT_LEN, + output_len: int = DEFAULT_OUTPUT_LEN, + return_prompt_formatted: bool = False, + **kwargs, + ) -> list: + # Calculate average token length for a poem line. + tokenized_lines = [tokenizer(line).input_ids for line in self.data] + avg_len = sum(len(tokens) + for tokens in tokenized_lines) / len(tokenized_lines) + + # Build the base prompt. + base_prompt = "Pick as many lines as you can from these poem lines:\n" + base_msg = [{"role": "user", "content": base_prompt}] + base_fmt = tokenizer.apply_chat_template(base_msg, + add_generation_prompt=True, + tokenize=False) + base_offset = len(tokenizer(base_fmt).input_ids) + if input_len <= base_offset: + raise ValueError( + f"'input_len' must be higher than the base prompt length " + f"({base_offset}).") + + # Determine how many poem lines to use. + num_input_lines = round((input_len - base_offset) / avg_len) + num_prefix_lines = max(round((prefix_len - base_offset) / avg_len), 0) + prefix_lines = self.data[:num_prefix_lines] + + samples = [] + while len(samples) < num_requests: + extra_lines = random.choices(self.data, + k=num_input_lines - num_prefix_lines) + prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}" + msg = [{"role": "user", "content": prompt}] + prompt_formatted = tokenizer.apply_chat_template( + msg, add_generation_prompt=True, tokenize=False) + prompt_len = len(tokenizer(prompt_formatted).input_ids) + if prompt_len <= input_len: + samples.append( + SampleRequest( + prompt=prompt_formatted + if return_prompt_formatted else prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + return samples + + +# ----------------------------------------------------------------------------- +# BurstGPT Dataset Implementation +# ----------------------------------------------------------------------------- + + +class BurstGPTDataset(BenchmarkDataset): + """ + Implements the BurstGPT dataset. Loads data from a CSV file and generates + sample requests based on synthetic prompt generation. Only rows with Model + "GPT-4" and positive response tokens are used. + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self, ): + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + df = pd.read_csv(self.dataset_path) + # Filter to keep only GPT-4 rows. + gpt4_df = df[df["Model"] == "GPT-4"] + # Remove failed requests (where Response tokens is 0 or less). + gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0] + # Sample the desired number of rows. + self.data = gpt4_df + + def _sample_loaded_data(self, num_requests: int) -> list: + if num_requests <= len(self.data): + data = self.data.sample(n=num_requests, + random_state=self.random_seed) + else: + data = self.data.sample( + n=num_requests, + random_state=self.random_seed, + replace=True, + ) + # Convert the dataframe to a list of lists. + return data.values.tolist() + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + max_loras: Optional[int] = None, + lora_path: Optional[str] = None, + **kwargs, + ) -> list[SampleRequest]: + samples = [] + data = self._sample_loaded_data(num_requests=num_requests) + for i in range(num_requests): + input_len = int(data[i][2]) + output_len = int(data[i][3]) + lora_req, tokenizer = self.get_random_lora_request( + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + vocab_size = tokenizer.vocab_size + # Generate a synthetic prompt: a list of token IDs computed as (i + + # j) modulo vocab_size. + token_ids = [(i + j) % vocab_size for j in range(input_len)] + prompt = tokenizer.decode(token_ids) + samples.append( + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=output_len, + lora_request=lora_req, + )) + return samples + + +# ----------------------------------------------------------------------------- +# HuggingFace Dataset Base Implementation +# ----------------------------------------------------------------------------- +class HuggingFaceDataset(BenchmarkDataset): + """Base class for datasets hosted on HuggingFace.""" + + SUPPORTED_DATASET_PATHS: Union[set[str], dict[str, Callable]] = set() + + def __init__( + self, + dataset_path: str, + dataset_split: str, + dataset_subset: Optional[str] = None, + **kwargs, + ) -> None: + super().__init__(dataset_path=dataset_path, **kwargs) + + self.dataset_split = dataset_split + self.dataset_subset = dataset_subset + self.load_data() + + def load_data(self) -> None: + """Load data from HuggingFace datasets.""" + self.data = load_dataset( + self.dataset_path, + name=self.dataset_subset, + split=self.dataset_split, + streaming=True, + ) + self.data = self.data.shuffle(seed=self.random_seed) + + +# ----------------------------------------------------------------------------- +# Conversation Dataset Implementation +# ----------------------------------------------------------------------------- + + +class ConversationDataset(HuggingFaceDataset): + """Dataset for conversation data with multimodal support.""" + SUPPORTED_DATASET_PATHS = { + 'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered' + } + IS_MULTIMODAL = True + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs) -> list: + # Filter examples with at least 2 conversations + filtered_data = self.data.filter( + lambda x: len(x["conversations"]) >= 2) + sampled_requests = [] + dynamic_output = output_len is None + + for item in filtered_data: + if len(sampled_requests) >= num_requests: + break + conv = item["conversations"] + prompt, completion = conv[0]["value"], conv[1]["value"] + + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + completion_len = len(completion_ids) + output_len = completion_len if dynamic_output else output_len + assert isinstance(output_len, int) and output_len > 0 + if dynamic_output and not is_valid_sequence( + prompt_len, completion_len): + continue + mm_content = process_image( + item["image"]) if "image" in item else None + if enable_multimodal_chat: + # Note: when chat is enabled the request prompt_len is no longer + # accurate and we will be using request output to count the + # actual prompt len and output len + prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Vision Arena Dataset Implementation +# ----------------------------------------------------------------------------- + + +class VisionArenaDataset(HuggingFaceDataset): + """ + Vision Arena Dataset. + """ + + DEFAULT_OUTPUT_LEN = 128 + SUPPORTED_DATASET_PATHS = { + "lmarena-ai/VisionArena-Chat": + lambda x: x["conversation"][0][0]["content"], + "lmarena-ai/vision-arena-bench-v0.1": + lambda x: x["turns"][0][0]["content"] + } + IS_MULTIMODAL = True + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) + if parser_fn is None: + raise ValueError( + f"Unsupported dataset path: {self.dataset_path}") + prompt = parser_fn(item) + mm_content = process_image(item["images"][0]) + prompt_len = len(tokenizer(prompt).input_ids) + if enable_multimodal_chat: + # Note: when chat is enabled the request prompt_len is no longer + # accurate and we will be using request output to count the + # actual prompt len + prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Instruct Coder Dataset Implementation +# ----------------------------------------------------------------------------- + + +class InstructCoderDataset(HuggingFaceDataset): + """ + InstructCoder Dataset. + https://huggingface.co/datasets/likaixin/InstructCoder + + InstructCoder is the dataset designed for general code editing. It consists + of 114,239 instruction-input-output triplets, and covers multiple distinct + code editing scenario. + """ + + DEFAULT_OUTPUT_LEN = 200 # this is the average default output length + SUPPORTED_DATASET_PATHS = { + "likaixin/InstructCoder", + } + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = f"{item['input']}\n\n{item['instruction']} Just output \ + the code, do not include any explanation." + + # apply template + prompt = tokenizer.apply_chat_template( + [{ + "role": "user", + "content": prompt + }], + add_generation_prompt=True, + tokenize=False, + ) + + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# MT-Bench Dataset Implementation +# ----------------------------------------------------------------------------- + + +class MTBenchDataset(HuggingFaceDataset): + """ + MT-Bench Dataset. + https://huggingface.co/datasets/philschmid/mt-bench + + We create a single turn dataset for MT-Bench. + This is similar to Spec decoding benchmark setup in vLLM + https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18 + """ # noqa: E501 + + DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM + SUPPORTED_DATASET_PATHS = { + "philschmid/mt-bench", + } + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + sampled_requests = [] + + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = item["turns"][0] + + # apply template + prompt = tokenizer.apply_chat_template( + [{ + "role": "user", + "content": prompt + }], + add_generation_prompt=True, + tokenize=False, + ) + + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# AIMO Dataset Implementation +# ----------------------------------------------------------------------------- + + +class AIMODataset(HuggingFaceDataset): + """ + Dataset class for processing a AIMO dataset with reasoning questions. + """ + SUPPORTED_DATASET_PATHS = { + "AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5", + "AI-MO/NuminaMath-CoT" + } + + def sample(self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + **kwargs) -> list: + sampled_requests = [] + dynamic_output = output_len is None + + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt, completion = item['problem'], item["solution"] + + prompt_ids = tokenizer(prompt).input_ids + completion_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_ids) + completion_len = len(completion_ids) + output_len = completion_len if dynamic_output else output_len + assert isinstance(output_len, int) and output_len > 0 + if dynamic_output and not is_valid_sequence(prompt_len, + completion_len, + max_prompt_len=2048, + max_total_len=32000): + continue + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=None, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Next Edit Prediction Dataset Implementation +# ----------------------------------------------------------------------------- + + +zeta_prompt = """### Instruction: +You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location. + +### User Edits: + +{} + +### User Excerpt: + +{} + +### Response: + +""" # noqa: E501 + + +def _format_zeta_prompt( + sample: dict, + original_start_marker: str = "<|editable_region_start|>") -> dict: + """Format the zeta prompt for the Next Edit Prediction (NEP) dataset. + + This function formats examples from the NEP dataset + into prompts and expected outputs. It could be + further extended to support more NEP datasets. + + Args: + sample: The dataset sample containing events, + inputs, and outputs. + original_start_marker: The marker indicating the + start of the editable region. Defaults to + "<|editable_region_start|>". + + Returns: + A dictionary with the formatted prompts and expected outputs. + """ + events = sample["events"] + input = sample["input"] + output = sample["output"] + prompt = zeta_prompt.format(events, input) + + # following the original implementation, extract the focused region + # from the raw output + output_start_index = output.find(original_start_marker) + output_focused_region = output[output_start_index:] + expected_output = output_focused_region + + return {"prompt": prompt, "expected_output": expected_output} + + +class NextEditPredictionDataset(HuggingFaceDataset): + """ + Dataset class for processing a Next Edit Prediction dataset. + """ + + SUPPORTED_DATASET_PATHS = { + "zed-industries/zeta", + } + MAPPING_PROMPT_FUNCS = { + "zed-industries/zeta": _format_zeta_prompt, + } + + def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, + **kwargs): + formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get( + self.dataset_path) + if formatting_prompt_func is None: + raise ValueError(f"Unsupported dataset path: {self.dataset_path}") + samples = [] + for sample in self.data: + sample = formatting_prompt_func(sample) + samples.append( + SampleRequest( + prompt=sample["prompt"], + prompt_len=len(tokenizer(sample["prompt"]).input_ids), + expected_output_len=len( + tokenizer(sample["expected_output"]).input_ids), + )) + if len(samples) >= num_requests: + break + self.maybe_oversample_requests(samples, num_requests) + return samples + + +# ----------------------------------------------------------------------------- +# ASR Dataset Implementation +# ----------------------------------------------------------------------------- + + +class ASRDataset(HuggingFaceDataset): + """ + Dataset class for processing a ASR dataset for transcription. + Tested on the following set: + + +----------------+----------------------------------------+--------------------------+-----------------------------+ + | Dataset | Domain | Speaking Style | hf-subset | + +----------------+----------------------------------------+--------------------------+-----------------------------+ + | TED-LIUM | TED talks | Oratory | release1, release2, release3| + | | | | release3-speaker-adaptation | + | VoxPopuli | European Parliament | Oratory | en, de, it, fr, ... | + | LibriSpeech | Audiobook | Narrated | "LIUM/tedlium" | + | GigaSpeech | Audiobook, podcast, YouTube | Narrated, spontaneous | xs, s, m, l, xl, dev, test | + | SPGISpeech | Financial meetings | Oratory, spontaneous | S, M, L, dev, test | + | AMI | Meetings | Spontaneous | ihm, sdm | + +----------------+----------------------------------------+--------------------------+-----------------------------+ + + """ # noqa: E501 + + SUPPORTED_DATASET_PATHS = { + "openslr/librispeech_asr", + "facebook/voxpopuli", + "LIUM/tedlium", + "edinburghcstr/ami", + "speechcolab/gigaspeech", + "kensho/spgispeech", + } + + DEFAULT_OUTPUT_LEN = 128 + IS_MULTIMODAL = True + + # TODO Whisper-specific. Abstract interface when more models are supported. + TRANSCRIPTION_PREAMBLE = ( + "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>") + skip_long_audios: bool = True + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + **kwargs, + ) -> list: + output_len = (output_len + if output_len is not None else self.DEFAULT_OUTPUT_LEN) + prompt = ASRDataset.TRANSCRIPTION_PREAMBLE + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests = [] + skipped = 0 + for item in self.data: + if len(sampled_requests) >= num_requests: + break + audio = item["audio"] + y, sr = audio["array"], audio["sampling_rate"] + duration_s = librosa.get_duration(y=y, sr=sr) + # Whisper max supported duration + if self.skip_long_audios and duration_s > 30: + skipped += 1 + continue + + mm_content = {"audio": (y, sr)} + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=mm_content, + )) + if skipped: + logger.warning( + "%d samples discarded from dataset due to" + " their length being greater than" + " what Whisper supports.", + skipped, + ) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests diff --git a/vllm/benchmarks/endpoint_request_func.py b/vllm/benchmarks/endpoint_request_func.py new file mode 100644 index 0000000..60ae520 --- /dev/null +++ b/vllm/benchmarks/endpoint_request_func.py @@ -0,0 +1,393 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""The request function for API endpoints.""" + +import io +import json +import os +import sys +import time +import traceback +from dataclasses import dataclass, field +from typing import Optional + +import aiohttp +from tqdm.asyncio import tqdm + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + + +@dataclass +class RequestFuncInput: + """The input for the request function.""" + prompt: str + api_url: str + prompt_len: int + output_len: int + model: str + model_name: Optional[str] = None + logprobs: Optional[int] = None + extra_body: Optional[dict] = None + multi_modal_content: Optional[dict] = None + ignore_eos: bool = False + language: Optional[str] = None + + +@dataclass +class RequestFuncOutput: + """The output of the request function including metrics.""" + generated_text: str = "" + success: bool = False + latency: float = 0.0 + output_tokens: int = 0 + ttft: float = 0.0 # Time to first token + itl: list[float] = field( + default_factory=list) # list of inter-token latencies + tpot: float = 0.0 # avg next-token latencies + prompt_len: int = 0 + error: str = "" + + +async def async_request_openai_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + """The async request function for the OpenAI Completions API. + + Args: + request_func_input: The input for the request function. + pbar: The progress bar to display the progress. + + Returns: + The output of the request function. + """ + api_url = request_func_input.api_url + assert api_url.endswith( + ("completions", "profile") + ), "OpenAI Completions API URL must end with 'completions' or 'profile'." + + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "model": request_func_input.model_name \ + if request_func_input.model_name else request_func_input.model, + "prompt": request_func_input.prompt, + "temperature": 0.0, + "repetition_penalty": 1.0, + "max_tokens": request_func_input.output_len, + "logprobs": request_func_input.logprobs, + "stream": True, + "stream_options": { + "include_usage": True, + }, + } + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + } + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload, + headers=headers) as response: + if response.status == 200: + first_chunk_received = False + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + chunk_bytes = chunk_bytes.decode("utf-8") + # NOTE: SSE comments (often used as pings) start with + # a colon. These are not JSON data payload and should + # be skipped. + if chunk_bytes.startswith(":"): + continue + + chunk = chunk_bytes.removeprefix("data: ") + + if chunk != "[DONE]": + data = json.loads(chunk) + + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if choices := data.get("choices"): + # Note that text could be empty here + # e.g. for special tokens + text = choices[0].get("text") + timestamp = time.perf_counter() + # First token + if not first_chunk_received: + first_chunk_received = True + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += text or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") + if first_chunk_received: + output.success = True + else: + output.success = False + output.error = ( + "Never received a valid chunk to calculate TTFT." + "This response will be marked as failed!") + output.generated_text = generated_text + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def async_request_openai_chat_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith(("chat/completions", "profile")), ( + "OpenAI Chat Completions API URL must end with 'chat/completions'.") + + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + content = [{"type": "text", "text": request_func_input.prompt}] + if request_func_input.multi_modal_content: + content.append(request_func_input.multi_modal_content) + payload = { + "model": + request_func_input.model_name + if request_func_input.model_name else request_func_input.model, + "messages": [ + { + "role": "user", + "content": content + }, + ], + "temperature": + 0.0, + "max_completion_tokens": + request_func_input.output_len, + "stream": + True, + "stream_options": { + "include_usage": True, + }, + } + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload, + headers=headers) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + chunk_bytes = chunk_bytes.decode("utf-8") + # NOTE: SSE comments (often used as pings) start with + # a colon. These are not JSON data payload and should + # be skipped. + if chunk_bytes.startswith(":"): + continue + + chunk = chunk_bytes.removeprefix("data: ") + + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) + + if choices := data.get("choices"): + content = choices[0]["delta"].get("content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - + most_recent_timestamp) + + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") + + most_recent_timestamp = timestamp + + output.generated_text = generated_text + output.success = True + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def async_request_openai_audio( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + # Lazy import without PlaceholderModule to avoid vllm dep. + import soundfile + + api_url = request_func_input.api_url + assert api_url.endswith(("transcriptions", "translations")), ( + "OpenAI Chat Completions API URL must end with 'transcriptions' ") + "or `translations`." + + async with aiohttp.ClientSession(trust_env=True, + timeout=AIOHTTP_TIMEOUT) as session: + content = [{"type": "text", "text": request_func_input.prompt}] + payload = { + "model": + request_func_input.model_name + if request_func_input.model_name else request_func_input.model, + "temperature": + 0.0, + "max_completion_tokens": + request_func_input.output_len, + "stream": + True, + "language": + "en", + # Flattened due to multipart/form-data + "stream_include_usage": + True, + "stream_continuous_usage_stats": + True, + } + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + } + + # Send audio file + def to_bytes(y, sr): + buffer = io.BytesIO() + soundfile.write(buffer, y, sr, format="WAV") + buffer.seek(0) + return buffer + + with to_bytes(*request_func_input.multi_modal_content["audio"]) as f: + form = aiohttp.FormData() + form.add_field("file", f, content_type="audio/wav") + for key, value in payload.items(): + form.add_field(key, str(value)) + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, + data=form, + headers=headers) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = chunk_bytes.decode("utf-8").removeprefix( + "data: ") + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) + + if choices := data.get("choices"): + content = choices[0]["delta"].get( + "content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append( + timestamp - most_recent_timestamp) + + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens") + + most_recent_timestamp = timestamp + + output.generated_text = generated_text + output.success = True + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +# TODO: Add more request functions for different API protocols. +ASYNC_REQUEST_FUNCS = { + "vllm": async_request_openai_completions, + "openai": async_request_openai_completions, + "openai-chat": async_request_openai_chat_completions, + "openai-audio": async_request_openai_audio, +} + +OPENAI_COMPATIBLE_BACKENDS = [ + k for k, v in ASYNC_REQUEST_FUNCS.items() + if v in (async_request_openai_completions, + async_request_openai_chat_completions) +] diff --git a/vllm/benchmarks/latency.py b/vllm/benchmarks/latency.py new file mode 100644 index 0000000..5c6124d --- /dev/null +++ b/vllm/benchmarks/latency.py @@ -0,0 +1,168 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Benchmark the latency of processing a single batch of requests.""" + +import argparse +import dataclasses +import json +import os +import time +from typing import Any, Optional + +import numpy as np +from tqdm import tqdm + +import vllm.envs as envs +from vllm import LLM, SamplingParams +from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format, + write_to_json) +from vllm.engine.arg_utils import EngineArgs +from vllm.inputs import PromptType +from vllm.sampling_params import BeamSearchParams + + +def save_to_pytorch_benchmark_format(args: argparse.Namespace, + results: dict[str, Any]) -> None: + pt_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={"latency": results["latencies"]}, + extra_info={k: results[k] + for k in ["avg_latency", "percentiles"]}) + if pt_records: + pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" + write_to_json(pt_file, pt_records) + + +def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument("--input-len", type=int, default=32) + parser.add_argument("--output-len", type=int, default=128) + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument( + "--n", + type=int, + default=1, + help="Number of generated sequences per prompt.", + ) + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument( + "--num-iters-warmup", + type=int, + default=10, + help="Number of iterations to run for warmup.", + ) + parser.add_argument("--num-iters", + type=int, + default=30, + help="Number of iterations to run.") + parser.add_argument( + "--profile", + action="store_true", + help="profile the generation process of a single batch", + ) + parser.add_argument( + "--output-json", + type=str, + default=None, + help="Path to save the latency results in JSON format.", + ) + parser.add_argument( + "--disable-detokenize", + action="store_true", + help=("Do not detokenize responses (i.e. do not include " + "detokenization time in the latency measurement)"), + ) + + parser = EngineArgs.add_cli_args(parser) + # V1 enables prefix caching by default which skews the latency + # numbers. We need to disable prefix caching by default. + parser.set_defaults(enable_prefix_caching=False) + + +def main(args: argparse.Namespace): + if args.profile and not envs.VLLM_TORCH_PROFILER_DIR: + raise OSError( + "The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. " + "Please set it to a valid path to use torch profiler.") + engine_args = EngineArgs.from_cli_args(args) + + # NOTE(woosuk): If the request cannot be processed in a single batch, + # the engine will automatically process the request in multiple batches. + llm = LLM(**dataclasses.asdict(engine_args)) + assert llm.llm_engine.model_config.max_model_len >= ( + args.input_len + + args.output_len), ("Please ensure that max_model_len is greater than" + " the sum of input_len and output_len.") + + sampling_params = SamplingParams( + n=args.n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=args.output_len, + detokenize=not args.disable_detokenize, + ) + dummy_prompt_token_ids = np.random.randint(10000, + size=(args.batch_size, + args.input_len)) + dummy_prompts: list[PromptType] = [{ + "prompt_token_ids": batch + } for batch in dummy_prompt_token_ids.tolist()] + + def llm_generate(): + if not args.use_beam_search: + llm.generate(dummy_prompts, + sampling_params=sampling_params, + use_tqdm=False) + else: + llm.beam_search( + dummy_prompts, + BeamSearchParams( + beam_width=args.n, + max_tokens=args.output_len, + ignore_eos=True, + ), + ) + + def run_to_completion(profile_dir: Optional[str] = None): + if profile_dir: + llm.start_profile() + llm_generate() + llm.stop_profile() + else: + start_time = time.perf_counter() + llm_generate() + end_time = time.perf_counter() + latency = end_time - start_time + return latency + + print("Warming up...") + for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"): + run_to_completion(profile_dir=None) + + if args.profile: + profile_dir = envs.VLLM_TORCH_PROFILER_DIR + print(f"Profiling (results will be saved to '{profile_dir}')...") + run_to_completion(profile_dir=profile_dir) + return + + # Benchmark. + latencies = [] + for _ in tqdm(range(args.num_iters), desc="Profiling iterations"): + latencies.append(run_to_completion(profile_dir=None)) + latencies = np.array(latencies) + percentages = [10, 25, 50, 75, 90, 99] + percentiles = np.percentile(latencies, percentages) + print(f"Avg latency: {np.mean(latencies)} seconds") + for percentage, percentile in zip(percentages, percentiles): + print(f"{percentage}% percentile latency: {percentile} seconds") + + # Output JSON results if specified + if args.output_json: + results = { + "avg_latency": np.mean(latencies), + "latencies": latencies.tolist(), + "percentiles": dict(zip(percentages, percentiles.tolist())), + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + save_to_pytorch_benchmark_format(args, results) diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py new file mode 100644 index 0000000..8b16fea --- /dev/null +++ b/vllm/benchmarks/serve.py @@ -0,0 +1,1063 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +r"""Benchmark online serving throughput. + +On the server side, run one of the following commands +to launch the vLLM OpenAI API server: + vllm serve + +On the client side, run: + vllm bench serve \ + --endpoint-type \ + --label \ + --model \ + --dataset-name \ + --request-rate \ + --num-prompts +""" +import argparse +import asyncio +import gc +import json +import os +import random +import time +import warnings +from collections.abc import AsyncGenerator, Iterable +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Literal, Optional + +import numpy as np +from tqdm.asyncio import tqdm +from transformers import PreTrainedTokenizerBase + +from vllm.benchmarks.datasets import (SampleRequest, add_dataset_parser, + get_samples) +from vllm.benchmarks.endpoint_request_func import (ASYNC_REQUEST_FUNCS, + OPENAI_COMPATIBLE_BACKENDS, + RequestFuncInput, + RequestFuncOutput) +from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format, + write_to_json) +from vllm.transformers_utils.tokenizer import get_tokenizer + +MILLISECONDS_TO_SECONDS_CONVERSION = 1000 + + +@dataclass +class BenchmarkMetrics: + completed: int + total_input: int + total_output: int + request_throughput: float + request_goodput: float + output_throughput: float + total_token_throughput: float + mean_ttft_ms: float + median_ttft_ms: float + std_ttft_ms: float + percentiles_ttft_ms: list[tuple[float, float]] + mean_tpot_ms: float + median_tpot_ms: float + std_tpot_ms: float + percentiles_tpot_ms: list[tuple[float, float]] + mean_itl_ms: float + median_itl_ms: float + std_itl_ms: float + percentiles_itl_ms: list[tuple[float, float]] + # E2EL stands for end-to-end latency per request. + # It is the time taken on the client side from sending + # a request to receiving a complete response. + mean_e2el_ms: float + median_e2el_ms: float + std_e2el_ms: float + percentiles_e2el_ms: list[tuple[float, float]] + + +def _get_current_request_rate( + ramp_up_strategy: Optional[Literal["linear", "exponential"]], + ramp_up_start_rps: Optional[int], + ramp_up_end_rps: Optional[int], + request_index: int, + total_requests: int, + request_rate: float, +) -> float: + if (ramp_up_strategy and ramp_up_start_rps is not None + and ramp_up_end_rps is not None): + progress = request_index / max(total_requests - 1, 1) + if ramp_up_strategy == "linear": + increase = (ramp_up_end_rps - ramp_up_start_rps) * progress + return ramp_up_start_rps + increase + elif ramp_up_strategy == "exponential": + ratio = ramp_up_end_rps / ramp_up_start_rps + return ramp_up_start_rps * (ratio**progress) + else: + raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}") + return request_rate + + +async def get_request( + input_requests: list[SampleRequest], + request_rate: float, + burstiness: float = 1.0, + ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, + ramp_up_start_rps: Optional[int] = None, + ramp_up_end_rps: Optional[int] = None, +) -> AsyncGenerator[tuple[SampleRequest, float], None]: + """ + Asynchronously generates requests at a specified rate + with OPTIONAL burstiness and OPTIONAL ramp-up strategy. + + Args: + input_requests: + A list of input requests, each represented as a SampleRequest. + request_rate: + The rate at which requests are generated (requests/s). + burstiness (optional): + The burstiness factor of the request generation. + Only takes effect when request_rate is not inf. + Default value is 1, which follows a Poisson process. + Otherwise, the request intervals follow a gamma distribution. + A lower burstiness value (0 < burstiness < 1) results + in more bursty requests, while a higher burstiness value + (burstiness > 1) results in a more uniform arrival of requests. + ramp_up_strategy (optional): + The ramp-up strategy. Can be "linear" or "exponential". + If None, uses constant request rate (specified by request_rate). + ramp_up_start_rps (optional): + The starting request rate for ramp-up. + ramp_up_end_rps (optional): + The ending request rate for ramp-up. + """ + assert burstiness > 0, ( + f"A positive burstiness factor is expected, but given {burstiness}.") + # Convert to list to get length for ramp-up calculations + if isinstance(input_requests, Iterable) and not isinstance( + input_requests, list): + input_requests = list(input_requests) + + total_requests = len(input_requests) + request_index = 0 + + for request in input_requests: + current_request_rate = _get_current_request_rate(ramp_up_strategy, + ramp_up_start_rps, + ramp_up_end_rps, + request_index, + total_requests, + request_rate) + + yield request, current_request_rate + + request_index += 1 + + if current_request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + + theta = 1.0 / (current_request_rate * burstiness) + + # Sample the request interval from the gamma distribution. + # If burstiness is 1, it follows exponential distribution. + interval = np.random.gamma(shape=burstiness, scale=theta) + # The next request will be sent after the interval. + await asyncio.sleep(interval) + + +def calculate_metrics( + input_requests: list[SampleRequest], + outputs: list[RequestFuncOutput], + dur_s: float, + tokenizer: PreTrainedTokenizerBase, + selected_percentiles: list[float], + goodput_config_dict: dict[str, float], +) -> tuple[BenchmarkMetrics, list[int]]: + """Calculate the metrics for the benchmark. + + Args: + input_requests: The input requests. + outputs: The outputs of the requests. + dur_s: The duration of the benchmark. + tokenizer: The tokenizer to use. + selected_percentiles: The percentiles to select. + goodput_config_dict: The goodput configuration. + + Returns: + A tuple of the benchmark metrics and the actual output lengths. + """ + actual_output_lens: list[int] = [] + total_input = 0 + completed = 0 + good_completed = 0 + itls: list[float] = [] + tpots: list[float] = [] + all_tpots: list[float] = [] + ttfts: list[float] = [] + e2els: list[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + output_len = outputs[i].output_tokens + + if not output_len: + # We use the tokenizer to count the number of output tokens + # for some serving backends instead of looking at + # len(outputs[i].itl) since multiple output tokens may be + # bundled together + # Note : this may inflate the output token count slightly + output_len = len( + tokenizer(outputs[i].generated_text, + add_special_tokens=False).input_ids) + actual_output_lens.append(output_len) + total_input += input_requests[i].prompt_len + tpot = 0 + if output_len > 1: + latency_minus_ttft = outputs[i].latency - outputs[i].ttft + tpot = latency_minus_ttft / (output_len - 1) + tpots.append(tpot) + # Note: if output_len <= 1, we regard tpot as 0 for goodput + all_tpots.append(tpot) + itls += outputs[i].itl + ttfts.append(outputs[i].ttft) + e2els.append(outputs[i].latency) + completed += 1 + else: + actual_output_lens.append(0) + + if goodput_config_dict: + valid_metrics = [] + slo_values = [] + + if "ttft" in goodput_config_dict: + valid_metrics.append(ttfts) + slo_values.append(goodput_config_dict["ttft"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + if "tpot" in goodput_config_dict: + valid_metrics.append(all_tpots) + slo_values.append(goodput_config_dict["tpot"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + if "e2el" in goodput_config_dict: + valid_metrics.append(e2els) + slo_values.append(goodput_config_dict["e2el"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + + for req_metric in zip(*valid_metrics): + is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) + if is_good_req: + good_completed += 1 + + if completed == 0: + warnings.warn( + "All requests failed. This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2) + metrics = BenchmarkMetrics( + completed=completed, + total_input=total_input, + total_output=sum(actual_output_lens), + request_throughput=completed / dur_s, + request_goodput=good_completed / dur_s, + output_throughput=sum(actual_output_lens) / dur_s, + total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, + mean_ttft_ms=np.mean(ttfts or 0) * + 1000, # ttfts is empty if streaming is not supported by the endpoint + std_ttft_ms=np.std(ttfts or 0) * 1000, + median_ttft_ms=np.median(ttfts or 0) * 1000, + percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) + for p in selected_percentiles], + mean_tpot_ms=np.mean(tpots or 0) * 1000, + std_tpot_ms=np.std(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) + for p in selected_percentiles], + mean_itl_ms=np.mean(itls or 0) * 1000, + std_itl_ms=np.std(itls or 0) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) + for p in selected_percentiles], + mean_e2el_ms=np.mean(e2els or 0) * 1000, + std_e2el_ms=np.std(e2els or 0) * 1000, + median_e2el_ms=np.median(e2els or 0) * 1000, + percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) + for p in selected_percentiles], + ) + + return metrics, actual_output_lens + + +async def benchmark( + endpoint_type: str, + api_url: str, + base_url: str, + model_id: str, + model_name: str, + tokenizer: PreTrainedTokenizerBase, + input_requests: list[SampleRequest], + logprobs: Optional[int], + request_rate: float, + burstiness: float, + disable_tqdm: bool, + profile: bool, + selected_percentile_metrics: list[str], + selected_percentiles: list[float], + ignore_eos: bool, + goodput_config_dict: dict[str, float], + max_concurrency: Optional[int], + lora_modules: Optional[Iterable[str]], + extra_body: Optional[dict], + ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, + ramp_up_start_rps: Optional[int] = None, + ramp_up_end_rps: Optional[int] = None, +): + if endpoint_type in ASYNC_REQUEST_FUNCS: + request_func = ASYNC_REQUEST_FUNCS[endpoint_type] + else: + raise ValueError(f"Unknown endpoint_type: {endpoint_type}") + + print("Starting initial single prompt test run...") + test_prompt, test_prompt_len, test_output_len, test_mm_content = ( + input_requests[0].prompt, + input_requests[0].prompt_len, + input_requests[0].expected_output_len, + input_requests[0].multi_modal_data, + ) + + assert test_mm_content is None or isinstance(test_mm_content, dict) + test_input = RequestFuncInput( + model=model_id, + model_name=model_name, + prompt=test_prompt, + api_url=api_url, + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos, + extra_body=extra_body, + ) + + test_output = await request_func(request_func_input=test_input) + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark arguments " + f"are correctly specified. Error: {test_output.error}") + else: + print("Initial test run completed. Starting main benchmark run...") + + if lora_modules: + # For each input request, choose a LoRA module at random. + lora_modules = iter( + [random.choice(lora_modules) for _ in range(len(input_requests))]) + + if profile: + print("Starting profiler...") + profile_input = RequestFuncInput(model=model_id, + model_name=model_name, + prompt=test_prompt, + api_url=base_url + "/start_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos, + extra_body=extra_body) + profile_output = await request_func(request_func_input=profile_input) + if profile_output.success: + print("Profiler started") + + distribution = ("Poisson process" if burstiness == 1.0 + else "Gamma distribution") + + if ramp_up_strategy is not None: + print(f"Traffic ramp-up strategy: {ramp_up_strategy}.") + print(f"Will increase RPS from {ramp_up_start_rps} to " + f"{ramp_up_end_rps} RPS over the duration of the benchmark.") + else: + print(f"Traffic request rate: {request_rate}") + + print(f"Burstiness factor: {burstiness} ({distribution})") + print(f"Maximum request concurrency: {max_concurrency}") + + pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + + # This can be used once the minimum Python version is 3.10 or higher, + # and it will simplify the code in limited_request_func. + # semaphore = (asyncio.Semaphore(max_concurrency) + # if max_concurrency else contextlib.nullcontext()) + semaphore = (asyncio.Semaphore(max_concurrency) + if max_concurrency else None) + + async def limited_request_func(request_func_input, pbar): + if semaphore is None: + return await request_func(request_func_input=request_func_input, + pbar=pbar) + async with semaphore: + return await request_func(request_func_input=request_func_input, + pbar=pbar) + + benchmark_start_time = time.perf_counter() + tasks: list[asyncio.Task] = [] + + rps_change_events = [] + last_int_rps = -1 + if ramp_up_strategy is not None and ramp_up_start_rps is not None: + last_int_rps = ramp_up_start_rps + rps_change_events.append({ + "rps": last_int_rps, + "timestamp": datetime.now().isoformat(), + }) + + async for request, current_request_rate in get_request( + input_requests, request_rate, burstiness, ramp_up_strategy, + ramp_up_start_rps, ramp_up_end_rps): + if ramp_up_strategy is not None: + current_int_rps = int(current_request_rate) + if current_int_rps > last_int_rps: + timestamp = datetime.now().isoformat() + for rps_val in range(last_int_rps + 1, current_int_rps + 1): + rps_change_events.append({ + "rps": rps_val, + "timestamp": timestamp + }) + last_int_rps = current_int_rps + prompt, prompt_len, output_len, mm_content = ( + request.prompt, + request.prompt_len, + request.expected_output_len, + request.multi_modal_data, + ) + req_model_id, req_model_name = model_id, model_name + if lora_modules: + req_lora_module = next(lora_modules) + req_model_id, req_model_name = req_lora_module, req_lora_module + + request_func_input = RequestFuncInput(model=req_model_id, + model_name=req_model_name, + prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + logprobs=logprobs, + multi_modal_content=mm_content, + ignore_eos=ignore_eos, + extra_body=extra_body) + tasks.append( + asyncio.create_task( + limited_request_func(request_func_input=request_func_input, + pbar=pbar))) + outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) + + if profile: + print("Stopping profiler...") + profile_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=base_url + "/stop_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + ) + profile_output = await request_func(request_func_input=profile_input) + if profile_output.success: + print("Profiler stopped") + + if pbar is not None: + pbar.close() + + benchmark_duration = time.perf_counter() - benchmark_start_time + + metrics, actual_output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + selected_percentiles=selected_percentiles, + goodput_config_dict=goodput_config_dict, + ) + + print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", + benchmark_duration)) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) + print("{:<40} {:<10}".format("Total generated tokens:", + metrics.total_output)) + print("{:<40} {:<10.2f}".format("Request throughput (req/s):", + metrics.request_throughput)) + if goodput_config_dict: + print("{:<40} {:<10.2f}".format("Request goodput (req/s):", + metrics.request_goodput)) + print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", + metrics.output_throughput)) + print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", + metrics.total_token_throughput)) + + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "request_throughput": metrics.request_throughput, + "request_goodput": + metrics.request_goodput if goodput_config_dict else None, + "output_throughput": metrics.output_throughput, + "total_token_throughput": metrics.total_token_throughput, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": actual_output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + } + + if rps_change_events: + result["rps_change_events"] = rps_change_events + + def process_one_metric( + # E.g., "ttft" + metric_attribute_name: str, + # E.g., "TTFT" + metric_name: str, + # E.g., "Time to First Token" + metric_header: str, + ): + # This function prints and adds statistics of the specified + # metric. + if metric_attribute_name not in selected_percentile_metrics: + return + print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) + print("{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_attribute_name}_ms"))) + print("{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_attribute_name}_ms"))) + result[f"mean_{metric_attribute_name}_ms"] = getattr( + metrics, f"mean_{metric_attribute_name}_ms") + result[f"median_{metric_attribute_name}_ms"] = getattr( + metrics, f"median_{metric_attribute_name}_ms") + result[f"std_{metric_attribute_name}_ms"] = getattr( + metrics, f"std_{metric_attribute_name}_ms") + for p, value in getattr(metrics, + f"percentiles_{metric_attribute_name}_ms"): + p_word = str(int(p)) if int(p) == p else str(p) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", + value)) + result[f"p{p_word}_{metric_attribute_name}_ms"] = value + + process_one_metric("ttft", "TTFT", "Time to First Token") + process_one_metric("tpot", "TPOT", + "Time per Output Token (excl. 1st token)") + process_one_metric("itl", "ITL", "Inter-token Latency") + process_one_metric("e2el", "E2EL", "End-to-end Latency") + + print("=" * 50) + + return result + + +def check_goodput_args(args): + # Check and parse goodput arguments + goodput_config_dict = {} + VALID_NAMES = ["ttft", "tpot", "e2el"] + if args.goodput: + goodput_config_dict = parse_goodput(args.goodput) + for slo_name, slo_val in goodput_config_dict.items(): + if slo_name not in VALID_NAMES: + raise ValueError( + f"Invalid metric name found, {slo_name}: {slo_val}. " + "The service level objective name should be one of " + f"{str(VALID_NAMES)}. ") + if slo_val < 0: + raise ValueError( + f"Invalid value found, {slo_name}: {slo_val}. " + "The service level objective value should be " + "non-negative.") + return goodput_config_dict + + +def parse_goodput(slo_pairs): + goodput_config_dict = {} + try: + for slo_pair in slo_pairs: + slo_name, slo_val = slo_pair.split(":") + goodput_config_dict[slo_name] = float(slo_val) + except ValueError as err: + raise argparse.ArgumentTypeError( + "Invalid format found for service level objectives. " + "Specify service level objectives for goodput as \"KEY:VALUE\" " + "pairs, where the key is a metric name, and the value is a " + "number in milliseconds.") from err + return goodput_config_dict + + +def save_to_pytorch_benchmark_format(args: argparse.Namespace, + results: dict[str, Any], + file_name: str) -> None: + metrics = [ + "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms", + "mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms", + "median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms" + ] + # These raw data might be useful, but they are rather big. They can be added + # later if needed + ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"] + pt_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={k: [results[k]] + for k in metrics}, + extra_info={ + k: results[k] + for k in results if k not in metrics and k not in ignored_metrics + }) + if pt_records: + # Don't use json suffix here as we don't want CI to pick it up + pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json" + write_to_json(pt_file, pt_records) + + +def add_cli_args(parser: argparse.ArgumentParser): + add_dataset_parser(parser) + parser.add_argument( + "--endpoint-type", + type=str, + default="openai", + choices=list(ASYNC_REQUEST_FUNCS.keys()), + ) + parser.add_argument( + "--label", + type=str, + default=None, + help="The label (prefix) of the benchmark results. If not specified, " + "the endpoint type will be used as the label.", + ) + parser.add_argument( + "--backend", + type=str, + default="vllm", + choices=list(ASYNC_REQUEST_FUNCS.keys()), + ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + # Use 127.0.0.1 here instead of localhost to force the use of ipv4 + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--endpoint", + type=str, + default="/v1/completions", + help="API endpoint.", + ) + parser.add_argument( + "--max-concurrency", + type=int, + default=None, + help="Maximum number of concurrent requests. This can be used " + "to help simulate an environment where a higher level component " + "is enforcing a maximum number of concurrent requests. While the " + "--request-rate argument controls the rate at which requests are " + "initiated, this argument will control how many are actually allowed " + "to execute at a time. This means that when used in combination, the " + "actual request rate may be lower than specified with --request-rate, " + "if the server is not processing requests fast enough to keep up.") + + parser.add_argument( + "--model", + type=str, + required=True, + help="Name of the model.", + ) + parser.add_argument( + "--tokenizer", + type=str, + help= + "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + ) + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument( + "--logprobs", + type=int, + default=None, + help=("Number of logprobs-per-token to compute & return as part of " + "the request. If unspecified, then either (1) if beam search " + "is disabled, no logprobs are computed & a single dummy " + "logprob is returned for each token; or (2) if beam search " + "is enabled 1 logprob per token is computed"), + ) + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, " + "then all the requests are sent at time 0. " + "Otherwise, we use Poisson process or gamma distribution " + "to synthesize the request arrival times.", + ) + parser.add_argument( + "--burstiness", + type=float, + default=1.0, + help="Burstiness factor of the request generation. " + "Only take effect when request_rate is not inf. " + "Default value is 1, which follows Poisson process. " + "Otherwise, the request intervals follow a gamma distribution. " + "A lower burstiness value (0 < burstiness < 1) results in more " + "bursty requests. A higher burstiness value (burstiness > 1) " + "results in a more uniform arrival of requests.", + ) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Trust remote code from huggingface", + ) + parser.add_argument( + "--disable-tqdm", + action="store_true", + help="Specify to disable tqdm progress bar.", + ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. The endpoint must be launched with " + "VLLM_TORCH_PROFILER_DIR to enable profiler.", + ) + parser.add_argument( + "--save-result", + action="store_true", + help="Specify to save benchmark results to a json file", + ) + parser.add_argument( + "--save-detailed", + action="store_true", + help="When saving the results, whether to include per request " + "information such as response, error, ttfs, tpots, etc.", + ) + parser.add_argument( + "--append-result", + action="store_true", + help="Append the benchmark result to the existing json file.", + ) + parser.add_argument( + "--metadata", + metavar="KEY=VALUE", + nargs="*", + help="Key-value pairs (e.g, --metadata version=0.3.3 tp=1) " + "for metadata of this run to be saved in the result JSON file " + "for record keeping purposes.", + ) + parser.add_argument( + "--result-dir", + type=str, + default=None, + help="Specify directory to save benchmark json results." + "If not specified, results are saved in the current directory.", + ) + parser.add_argument( + "--result-filename", + type=str, + default=None, + help="Specify the filename to save benchmark json results." + "If not specified, results will be saved in " + "{label}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" # noqa + " format.", + ) + parser.add_argument( + "--ignore-eos", + action="store_true", + help="Set ignore_eos flag when sending the benchmark request." + "Warning: ignore_eos is not supported in deepspeed_mii and tgi.") + parser.add_argument( + "--percentile-metrics", + type=str, + default="ttft,tpot,itl", + help="Comma-separated list of selected metrics to report percentils. " + "This argument specifies the metrics to report percentiles. " + "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". ") + parser.add_argument( + "--metric-percentiles", + type=str, + default="99", + help="Comma-separated list of percentiles for selected metrics. " + "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " + "Default value is \"99\"." + "Use \"--percentile-metrics\" to select metrics.", + ) + parser.add_argument( + "--goodput", + nargs="+", + required=False, + help="Specify service level objectives for goodput as \"KEY:VALUE\" " + "pairs, where the key is a metric name, and the value is in " + "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " + "separated by spaces. Allowed request level metric names are " + "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " + "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " + "and the blog: https://hao-ai-lab.github.io/blogs/distserve", + ) + + sampling_group = parser.add_argument_group("sampling parameters") + sampling_group.add_argument( + "--top-p", + type=float, + default=None, + help="Top-p sampling parameter. Only has effect on " + "openai-compatible backends.", + ) + sampling_group.add_argument( + "--top-k", + type=int, + default=None, + help="Top-k sampling parameter. Only has effect on " + "openai-compatible backends.", + ) + sampling_group.add_argument( + "--min-p", + type=float, + default=None, + help="Min-p sampling parameter. Only has effect on " + "openai-compatible backends.", + ) + sampling_group.add_argument( + "--temperature", + type=float, + default=None, + help="Temperature sampling parameter. Only has effect on " + "openai-compatible backends. If not specified, default to greedy " + "decoding (i.e. temperature==0.0).", + ) + + parser.add_argument( + '--tokenizer-mode', + type=str, + default="auto", + choices=['auto', 'slow', 'mistral', 'custom'], + help='The tokenizer mode.\n\n* "auto" will use the ' + 'fast tokenizer if available.\n* "slow" will ' + 'always use the slow tokenizer. \n* ' + '"mistral" will always use the `mistral_common` tokenizer. \n*' + '"custom" will use --tokenizer to select the preregistered tokenizer.') + + parser.add_argument("--served-model-name", + type=str, + default=None, + help="The model name used in the API. " + "If not specified, the model name will be the " + "same as the ``--model`` argument. ") + + parser.add_argument("--lora-modules", + nargs='+', + default=None, + help="A subset of LoRA module names passed in when " + "launching the server. For each request, the " + "script chooses a LoRA module at random.") + + parser.add_argument( + "--ramp-up-strategy", + type=str, + default=None, + choices=["linear", "exponential"], + help="The ramp-up strategy. This would be used to " + "ramp up the request rate from initial RPS to final " + "RPS rate (specified by --ramp-up-start-rps and " + "--ramp-up-end-rps.) over the duration of the benchmark." + ) + parser.add_argument( + "--ramp-up-start-rps", + type=int, + default=None, + help="The starting request rate for ramp-up (RPS). " + "Needs to be specified when --ramp-up-strategy is used.", + ) + parser.add_argument( + "--ramp-up-end-rps", + type=int, + default=None, + help="The ending request rate for ramp-up (RPS). " + "Needs to be specified when --ramp-up-strategy is used.", + ) + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + np.random.seed(args.seed) + + # Validate ramp-up arguments + if args.ramp_up_strategy is not None: + if args.request_rate != float("inf"): + raise ValueError( + "When using ramp-up, do not specify --request-rate. " + "The request rate will be controlled by ramp-up parameters. " + "Please remove the --request-rate argument." + ) + if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None: + raise ValueError( + "When using --ramp-up-strategy, both --ramp-up-start-rps and " + "--ramp-up-end-rps must be specified" + ) + if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0: + raise ValueError("Ramp-up start and end RPS must be non-negative") + if args.ramp_up_start_rps > args.ramp_up_end_rps: + raise ValueError("Ramp-up start RPS must be less than end RPS") + if (args.ramp_up_strategy == "exponential" + and args.ramp_up_start_rps == 0): + raise ValueError( + "For exponential ramp-up, the start RPS cannot be 0.") + + endpoint_type = args.endpoint_type + label = args.label + model_id = args.model + model_name = args.served_model_name + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + tokenizer_mode = args.tokenizer_mode + + if args.base_url is not None: + api_url = f"{args.base_url}{args.endpoint}" + base_url = f"{args.base_url}" + else: + api_url = f"http://{args.host}:{args.port}{args.endpoint}" + base_url = f"http://{args.host}:{args.port}" + + tokenizer = get_tokenizer(tokenizer_id, + tokenizer_mode=tokenizer_mode, + trust_remote_code=args.trust_remote_code) + + if args.dataset_name is None: + raise ValueError( + "Please specify '--dataset-name' and the corresponding " + "'--dataset-path' if required.") + + # Load the dataset. + input_requests = get_samples(args, tokenizer) + goodput_config_dict = check_goodput_args(args) + + # Collect the sampling parameters. + sampling_params = { + k: v + for k, v in { + "top_p": args.top_p, + "top_k": args.top_k, + "min_p": args.min_p, + "temperature": args.temperature, + }.items() if v is not None + } + + # Sampling parameters are only supported by openai-compatible backend. + if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS: + raise ValueError("Sampling parameters are only supported by " + "openai-compatible backends.") + + if "temperature" not in sampling_params: + sampling_params["temperature"] = 0.0 # Default to greedy decoding. + + # Avoid GC processing "static" data - reduce pause times. + gc.collect() + gc.freeze() + + benchmark_result = asyncio.run( + benchmark( + endpoint_type=args.endpoint_type, + api_url=api_url, + base_url=base_url, + model_id=model_id, + model_name=model_name, + tokenizer=tokenizer, + input_requests=input_requests, + logprobs=args.logprobs, + request_rate=args.request_rate, + burstiness=args.burstiness, + disable_tqdm=args.disable_tqdm, + profile=args.profile, + selected_percentile_metrics=args.percentile_metrics.split(","), + selected_percentiles=[ + float(p) for p in args.metric_percentiles.split(",") + ], + ignore_eos=args.ignore_eos, + goodput_config_dict=goodput_config_dict, + max_concurrency=args.max_concurrency, + lora_modules=args.lora_modules, + extra_body=sampling_params, + ramp_up_strategy=args.ramp_up_strategy, + ramp_up_start_rps=args.ramp_up_start_rps, + ramp_up_end_rps=args.ramp_up_end_rps, + )) + + # Save config and results to json + if args.save_result or args.append_result: + result_json: dict[str, Any] = {} + + # Setup + current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") + result_json["date"] = current_dt + result_json["endpoint_type"] = args.endpoint_type + result_json["label"] = label + result_json["model_id"] = model_id + result_json["tokenizer_id"] = tokenizer_id + result_json["num_prompts"] = args.num_prompts + + # Metadata + if args.metadata: + for item in args.metadata: + if "=" in item: + kvstring = item.split("=") + result_json[kvstring[0].strip()] = kvstring[1].strip() + else: + raise ValueError( + "Invalid metadata format. Please use KEY=VALUE format." + ) + + # Traffic + result_json["request_rate"] = (args.request_rate if args.request_rate + < float("inf") else "inf") + result_json["burstiness"] = args.burstiness + result_json["max_concurrency"] = args.max_concurrency + + if args.ramp_up_strategy is not None: + result_json["ramp_up_strategy"] = args.ramp_up_strategy + result_json["ramp_up_start_rps"] = args.ramp_up_start_rps + result_json["ramp_up_end_rps"] = args.ramp_up_end_rps + + # Merge with benchmark result + result_json = {**result_json, **benchmark_result} + + if not args.save_detailed: + # Remove fields with too many data points + for field in [ + "input_lens", + "output_lens", + "ttfts", + "itls", + "generated_texts", + "errors", + ]: + if field in result_json: + del result_json[field] + if field in benchmark_result: + del benchmark_result[field] + + # Save to file + base_model_id = model_id.split("/")[-1] + max_concurrency_str = (f"-concurrency{args.max_concurrency}" + if args.max_concurrency is not None else "") + label = label or endpoint_type + if args.ramp_up_strategy is not None: + file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa + else: + file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa + if args.result_filename: + file_name = args.result_filename + if args.result_dir: + os.makedirs(args.result_dir, exist_ok=True) + file_name = os.path.join(args.result_dir, file_name) + with open(file_name, + mode="a+" if args.append_result else "w", + encoding="utf-8") as outfile: + # Append a newline. + if args.append_result and outfile.tell() != 0: + outfile.write("\n") + json.dump(result_json, outfile) + save_to_pytorch_benchmark_format(args, result_json, file_name) diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py new file mode 100644 index 0000000..af2ca96 --- /dev/null +++ b/vllm/benchmarks/throughput.py @@ -0,0 +1,609 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Benchmark offline inference throughput.""" +import argparse +import dataclasses +import json +import os +import random +import time +import warnings +from typing import Any, Optional, Union + +import torch +import uvloop +from tqdm import tqdm +from transformers import (AutoModelForCausalLM, AutoTokenizer, + PreTrainedTokenizerBase) + +from vllm.benchmarks.datasets import (AIMODataset, BurstGPTDataset, + ConversationDataset, + InstructCoderDataset, RandomDataset, + SampleRequest, ShareGPTDataset, + SonnetDataset, VisionArenaDataset) +from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format, + write_to_json) +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.entrypoints.openai.api_server import ( + build_async_engine_client_from_engine_args) +from vllm.inputs import TextPrompt, TokensPrompt +from vllm.lora.request import LoRARequest +from vllm.outputs import RequestOutput +from vllm.sampling_params import BeamSearchParams +from vllm.utils import merge_async_iterators + + +def run_vllm( + requests: list[SampleRequest], + n: int, + engine_args: EngineArgs, + disable_detokenize: bool = False, +) -> tuple[float, Optional[list[RequestOutput]]]: + from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) + assert all( + llm.llm_engine.model_config.max_model_len >= ( + request.prompt_len + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests.") + # Add the requests to the engine. + prompts: list[Union[TextPrompt, TokensPrompt]] = [] + sampling_params: list[SamplingParams] = [] + for request in requests: + prompts.append( + TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], + multi_modal_data=request.multi_modal_data) + if "prompt_token_ids" in request.prompt else \ + TextPrompt(prompt=request.prompt, + multi_modal_data=request.multi_modal_data)) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + detokenize=not disable_detokenize, + )) + lora_requests: Optional[list[LoRARequest]] = None + if engine_args.enable_lora: + lora_requests = [request.lora_request for request in requests] + + use_beam_search = False + + outputs = None + if not use_beam_search: + start = time.perf_counter() + outputs = llm.generate(prompts, + sampling_params, + lora_request=lora_requests, + use_tqdm=True) + end = time.perf_counter() + else: + assert lora_requests is None, "BeamSearch API does not support LoRA" + prompts = [request.prompt for request in requests] + # output_len should be the same for all requests. + output_len = requests[0].expected_output_len + for request in requests: + assert request.expected_output_len == output_len + start = time.perf_counter() + llm.beam_search( + prompts, + BeamSearchParams( + beam_width=n, + max_tokens=output_len, + ignore_eos=True, + )) + end = time.perf_counter() + return end - start, outputs + + +def run_vllm_chat( + requests: list[SampleRequest], + n: int, + engine_args: EngineArgs, + disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]: + """ + Run vLLM chat benchmark. This function is recommended ONLY for benchmarking + multimodal models as it properly handles multimodal inputs and chat + formatting. For non-multimodal models, use run_vllm() instead. + """ + from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) + + assert all( + llm.llm_engine.model_config.max_model_len >= ( + request.prompt_len + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of " + "prompt_len and expected_output_len for all requests.") + + prompts = [] + sampling_params: list[SamplingParams] = [] + for request in requests: + prompts.append(request.prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + detokenize=not disable_detokenize, + )) + start = time.perf_counter() + outputs = llm.chat(prompts, sampling_params, use_tqdm=True) + end = time.perf_counter() + return end - start, outputs + + +async def run_vllm_async( + requests: list[SampleRequest], + n: int, + engine_args: AsyncEngineArgs, + disable_frontend_multiprocessing: bool = False, + disable_detokenize: bool = False, +) -> float: + from vllm import SamplingParams + + async with build_async_engine_client_from_engine_args( + engine_args, disable_frontend_multiprocessing) as llm: + model_config = await llm.get_model_config() + assert all( + model_config.max_model_len >= (request.prompt_len + + request.expected_output_len) + for request in requests), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests.") + + # Add the requests to the engine. + prompts: list[Union[TextPrompt, TokensPrompt]] = [] + sampling_params: list[SamplingParams] = [] + lora_requests: list[Optional[LoRARequest]] = [] + for request in requests: + prompts.append( + TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], + multi_modal_data=request.multi_modal_data) + if "prompt_token_ids" in request.prompt else \ + TextPrompt(prompt=request.prompt, + multi_modal_data=request.multi_modal_data)) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + detokenize=not disable_detokenize, + )) + lora_requests.append(request.lora_request) + + generators = [] + start = time.perf_counter() + for i, (prompt, sp, + lr) in enumerate(zip(prompts, sampling_params, lora_requests)): + generator = llm.generate(prompt, + sp, + lora_request=lr, + request_id=f"test{i}") + generators.append(generator) + all_gens = merge_async_iterators(*generators) + async for i, res in all_gens: + pass + end = time.perf_counter() + return end - start + + +def run_hf( + requests: list[SampleRequest], + model: str, + tokenizer: PreTrainedTokenizerBase, + n: int, + max_batch_size: int, + trust_remote_code: bool, + disable_detokenize: bool = False, +) -> float: + llm = AutoModelForCausalLM.from_pretrained( + model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + if llm.config.model_type == "llama": + # To enable padding in the HF backend. + tokenizer.pad_token = tokenizer.eos_token + llm = llm.cuda() + + pbar = tqdm(total=len(requests)) + start = time.perf_counter() + batch: list[str] = [] + max_prompt_len = 0 + max_output_len = 0 + for i in range(len(requests)): + prompt = requests[i].prompt + prompt_len = requests[i].prompt_len + output_len = requests[i].expected_output_len + # Add the prompt to the batch. + batch.append(prompt) + max_prompt_len = max(max_prompt_len, prompt_len) + max_output_len = max(max_output_len, output_len) + if len(batch) < max_batch_size and i != len(requests) - 1: + # Check if we can add more requests to the batch. + next_prompt_len = requests[i + 1].prompt_len + next_output_len = requests[i + 1].expected_output_len + if (max(max_prompt_len, next_prompt_len) + + max(max_output_len, next_output_len)) <= 2048: + # We can add more requests to the batch. + continue + + # Generate the sequences. + input_ids = tokenizer(batch, return_tensors="pt", + padding=True).input_ids + llm_outputs = llm.generate( + input_ids=input_ids.cuda(), + do_sample=True, + num_return_sequences=n, + temperature=1.0, + top_p=1.0, + use_cache=True, + max_new_tokens=max_output_len, + ) + if not disable_detokenize: + # Include the decoding time. + tokenizer.batch_decode(llm_outputs, skip_special_tokens=True) + pbar.update(len(batch)) + + # Clear the batch. + batch = [] + max_prompt_len = 0 + max_output_len = 0 + end = time.perf_counter() + return end - start + + +def save_to_pytorch_benchmark_format(args: argparse.Namespace, + results: dict[str, Any]) -> None: + pt_records = convert_to_pytorch_benchmark_format( + args=args, + metrics={ + "requests_per_second": [results["requests_per_second"]], + "tokens_per_second": [results["tokens_per_second"]], + }, + extra_info={ + k: results[k] + for k in ["elapsed_time", "num_requests", "total_num_tokens"] + }) + if pt_records: + # Don't use json suffix here as we don't want CI to pick it up + pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" + write_to_json(pt_file, pt_records) + + +def get_requests(args, tokenizer): + # Common parameters for all dataset types. + common_kwargs = { + "dataset_path": args.dataset_path, + "random_seed": args.seed, + } + sample_kwargs = { + "tokenizer": tokenizer, + "lora_path": args.lora_path, + "max_loras": args.max_loras, + "num_requests": args.num_prompts, + "input_len": args.input_len, + "output_len": args.output_len, + } + + if args.dataset_path is None or args.dataset_name == "random": + sample_kwargs["range_ratio"] = args.random_range_ratio + sample_kwargs["prefix_len"] = args.prefix_len + dataset_cls = RandomDataset + elif args.dataset_name == "sharegpt": + dataset_cls = ShareGPTDataset + if args.backend == "vllm-chat": + sample_kwargs["enable_multimodal_chat"] = True + elif args.dataset_name == "sonnet": + assert tokenizer.chat_template or tokenizer.default_chat_template, ( + "Tokenizer/model must have chat template for sonnet dataset.") + dataset_cls = SonnetDataset + sample_kwargs["prefix_len"] = args.prefix_len + sample_kwargs["return_prompt_formatted"] = True + elif args.dataset_name == "burstgpt": + dataset_cls = BurstGPTDataset + elif args.dataset_name == "hf": + if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: + dataset_cls = VisionArenaDataset + common_kwargs['dataset_subset'] = None + common_kwargs['dataset_split'] = "train" + sample_kwargs["enable_multimodal_chat"] = True + elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: + dataset_cls = InstructCoderDataset + common_kwargs['dataset_split'] = "train" + elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: + dataset_cls = ConversationDataset + common_kwargs['dataset_subset'] = args.hf_subset + common_kwargs['dataset_split'] = args.hf_split + sample_kwargs["enable_multimodal_chat"] = True + elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: + dataset_cls = AIMODataset + common_kwargs['dataset_subset'] = None + common_kwargs['dataset_split'] = "train" + else: + raise ValueError(f"Unknown dataset name: {args.dataset_name}") + # Remove None values + sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None} + return dataset_cls(**common_kwargs).sample(**sample_kwargs) + + +def validate_args(args): + """ + Validate command-line arguments. + """ + + # === Deprecation and Defaulting === + if args.dataset is not None: + warnings.warn( + "The '--dataset' argument will be deprecated in the next release. " + "Please use '--dataset-name' and '--dataset-path' instead.", + stacklevel=2) + args.dataset_path = args.dataset + + if not getattr(args, "tokenizer", None): + args.tokenizer = args.model + + # === Backend Validation === + valid_backends = {"vllm", "hf", "mii", "vllm-chat"} + if args.backend not in valid_backends: + raise ValueError(f"Unsupported backend: {args.backend}") + + # === Dataset Configuration === + if not args.dataset and not args.dataset_path: + print( + "When dataset path is not set, it will default to random dataset") + args.dataset_name = 'random' + if args.input_len is None: + raise ValueError("input_len must be provided for a random dataset") + + # === Dataset Name Specific Checks === + # --hf-subset and --hf-split: only used + # when dataset_name is 'hf' + if args.dataset_name != "hf" and ( + getattr(args, "hf_subset", None) is not None + or getattr(args, "hf_split", None) is not None): + warnings.warn("--hf-subset and --hf-split will be ignored \ + since --dataset-name is not 'hf'.", + stacklevel=2) + elif args.dataset_name == "hf": + if args.dataset_path in ( + VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() + | ConversationDataset.SUPPORTED_DATASET_PATHS): + assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501 + elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS + | AIMODataset.SUPPORTED_DATASET_PATHS): + assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501 + else: + raise ValueError( + f"{args.dataset_path} is not supported by hf dataset.") + + # --random-range-ratio: only used when dataset_name is 'random' + if args.dataset_name != 'random' and args.random_range_ratio is not None: + warnings.warn("--random-range-ratio will be ignored since \ + --dataset-name is not 'random'.", + stacklevel=2) + + # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not + # set. + if args.dataset_name not in {"random", "sonnet", None + } and args.prefix_len is not None: + warnings.warn("--prefix-len will be ignored since --dataset-name\ + is not 'random', 'sonnet', or not set.", + stacklevel=2) + + # === LoRA Settings === + if getattr(args, "enable_lora", False) and args.backend != "vllm": + raise ValueError( + "LoRA benchmarking is only supported for vLLM backend") + if getattr(args, "enable_lora", False) and args.lora_path is None: + raise ValueError("LoRA path must be provided when enable_lora is True") + + # === Backend-specific Validations === + if args.backend == "hf" and args.hf_max_batch_size is None: + raise ValueError("HF max batch size is required for HF backend") + if args.backend != "hf" and args.hf_max_batch_size is not None: + raise ValueError("HF max batch size is only for HF backend.") + + if args.backend in {"hf", "mii"} and getattr(args, "quantization", + None) is not None: + raise ValueError("Quantization is only for vLLM backend.") + + if args.backend == "mii" and args.dtype != "auto": + raise ValueError("dtype must be auto for MII backend.") + if args.backend == "mii" and args.n != 1: + raise ValueError("n must be 1 for MII backend.") + if args.backend == "mii" and args.tokenizer != args.model: + raise ValueError( + "Tokenizer must be the same as the model for MII backend.") + + +def add_cli_args(parser: argparse.ArgumentParser): + parser.add_argument("--backend", + type=str, + choices=["vllm", "hf", "mii", "vllm-chat"], + default="vllm") + parser.add_argument( + "--dataset-name", + type=str, + choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"], + help="Name of the dataset to benchmark on.", + default="sharegpt") + parser.add_argument( + "--dataset", + type=str, + default=None, + help="Path to the ShareGPT dataset, will be deprecated in\ + the next release. The dataset is expected to " + "be a json in form of list[dict[..., conversations: " + "list[dict[..., value: ]]]]") + parser.add_argument("--dataset-path", + type=str, + default=None, + help="Path to the dataset") + parser.add_argument("--input-len", + type=int, + default=None, + help="Input prompt length for each request") + parser.add_argument("--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.") + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.") + parser.add_argument("--hf-max-batch-size", + type=int, + default=None, + help="Maximum batch size for HF backend.") + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') + parser.add_argument("--async-engine", + action='store_true', + default=False, + help="Use vLLM async engine rather than LLM class.") + parser.add_argument("--disable-frontend-multiprocessing", + action='store_true', + default=False, + help="Disable decoupled async engine frontend.") + parser.add_argument( + "--disable-detokenize", + action="store_true", + help=("Do not detokenize the response (i.e. do not include " + "detokenization time in the measurement)")) + # LoRA + parser.add_argument( + "--lora-path", + type=str, + default=None, + help="Path to the lora adapters to use. This can be an absolute path, " + "a relative path, or a Hugging Face model identifier.") + parser.add_argument( + "--prefix-len", + type=int, + default=0, + help="Number of fixed prefix tokens before the random " + "context in a request (default: 0).", + ) + # random dataset + parser.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range ratio for sampling input/output length, " + "used only for RandomDataset. Must be in the range [0, 1) to define " + "a symmetric sampling range " + "[length * (1 - range_ratio), length * (1 + range_ratio)].", + ) + + # hf dtaset + parser.add_argument("--hf-subset", + type=str, + default=None, + help="Subset of the HF dataset.") + parser.add_argument("--hf-split", + type=str, + default=None, + help="Split of the HF dataset.") + + parser = AsyncEngineArgs.add_cli_args(parser) + + +def main(args: argparse.Namespace): + if args.tokenizer is None: + args.tokenizer = args.model + validate_args(args) + if args.seed is None: + args.seed = 0 + random.seed(args.seed) + # Sample the requests. + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code) + requests = get_requests(args, tokenizer) + is_multi_modal = any(request.multi_modal_data is not None + for request in requests) + request_outputs: Optional[list[RequestOutput]] = None + if args.backend == "vllm": + if args.async_engine: + elapsed_time = uvloop.run( + run_vllm_async( + requests, + args.n, + AsyncEngineArgs.from_cli_args(args), + args.disable_frontend_multiprocessing, + args.disable_detokenize, + )) + else: + elapsed_time, request_outputs = run_vllm( + requests, args.n, EngineArgs.from_cli_args(args), + args.disable_detokenize) + elif args.backend == "hf": + assert args.tensor_parallel_size == 1 + elapsed_time = run_hf(requests, args.model, tokenizer, args.n, + args.hf_max_batch_size, args.trust_remote_code, + args.disable_detokenize) + elif args.backend == "vllm-chat": + elapsed_time, request_outputs = run_vllm_chat( + requests, args.n, EngineArgs.from_cli_args(args), + args.disable_detokenize) + else: + raise ValueError(f"Unknown backend: {args.backend}") + + if request_outputs: + # Note: with the vllm and vllm-chat backends, + # we have request_outputs, which we use to count tokens. + total_prompt_tokens = 0 + total_output_tokens = 0 + for ro in request_outputs: + if not isinstance(ro, RequestOutput): + continue + total_prompt_tokens += len( + ro.prompt_token_ids) if ro.prompt_token_ids else 0 + total_output_tokens += sum( + len(o.token_ids) for o in ro.outputs if o) + total_num_tokens = total_prompt_tokens + total_output_tokens + else: + total_num_tokens = sum(r.prompt_len + r.expected_output_len + for r in requests) + total_output_tokens = sum(r.expected_output_len for r in requests) + total_prompt_tokens = total_num_tokens - total_output_tokens + + if is_multi_modal and args.backend != "vllm-chat": + print("\033[91mWARNING\033[0m: Multi-modal request with " + f"{args.backend} backend detected. The " + "following metrics are not accurate because image tokens are not" + " counted. See vllm-project/vllm/issues/9778 for details.") + # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length. + # vllm-chat backend counts the image tokens now + + print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " + f"{total_output_tokens / elapsed_time:.2f} output tokens/s") + print(f"Total num prompt tokens: {total_prompt_tokens}") + print(f"Total num output tokens: {total_output_tokens}") + + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + save_to_pytorch_benchmark_format(args, results) diff --git a/vllm/benchmarks/utils.py b/vllm/benchmarks/utils.py new file mode 100644 index 0000000..f0bb993 --- /dev/null +++ b/vllm/benchmarks/utils.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import json +import math +import os +from typing import Any + + +def convert_to_pytorch_benchmark_format(args: argparse.Namespace, + metrics: dict[str, list], + extra_info: dict[str, Any]) -> list: + """ + Save the benchmark results in the format used by PyTorch OSS benchmark with + on metric per record + https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database + """ + records = [] + if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False): + return records + + for name, benchmark_values in metrics.items(): + record = { + "benchmark": { + "name": "vLLM benchmark", + "extra_info": { + "args": vars(args), + }, + }, + "model": { + "name": args.model, + }, + "metric": { + "name": name, + "benchmark_values": benchmark_values, + "extra_info": extra_info, + }, + } + + tp = record["benchmark"]["extra_info"]["args"].get( + "tensor_parallel_size") + # Save tensor_parallel_size parameter if it's part of the metadata + if not tp and "tensor_parallel_size" in extra_info: + record["benchmark"]["extra_info"]["args"][ + "tensor_parallel_size"] = extra_info["tensor_parallel_size"] + + records.append(record) + + return records + + +class InfEncoder(json.JSONEncoder): + + def clear_inf(self, o: Any): + if isinstance(o, dict): + return {k: self.clear_inf(v) for k, v in o.items()} + elif isinstance(o, list): + return [self.clear_inf(v) for v in o] + elif isinstance(o, float) and math.isinf(o): + return "inf" + return o + + def iterencode(self, o: Any, *args, **kwargs) -> Any: + return super().iterencode(self.clear_inf(o), *args, **kwargs) + + +def write_to_json(filename: str, records: list) -> None: + with open(filename, "w") as f: + json.dump(records, f, cls=InfEncoder) diff --git a/vllm/collect_env.py b/vllm/collect_env.py new file mode 100644 index 0000000..64172a9 --- /dev/null +++ b/vllm/collect_env.py @@ -0,0 +1,820 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# ruff: noqa +# code borrowed from https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py + +import datetime +import locale +import os +import subprocess +import sys +# Unlike the rest of the PyTorch this file must be python2 compliant. +# This script outputs relevant system environment info +# Run it with `python collect_env.py` or `python -m torch.utils.collect_env` +from collections import namedtuple + +import regex as re + +from vllm.envs import environment_variables + +try: + import torch + TORCH_AVAILABLE = True +except (ImportError, NameError, AttributeError, OSError): + TORCH_AVAILABLE = False + +# System Environment Information +SystemEnv = namedtuple( + 'SystemEnv', + [ + 'torch_version', + 'is_debug_build', + 'cuda_compiled_version', + 'gcc_version', + 'clang_version', + 'cmake_version', + 'os', + 'libc_version', + 'python_version', + 'python_platform', + 'is_cuda_available', + 'cuda_runtime_version', + 'cuda_module_loading', + 'nvidia_driver_version', + 'nvidia_gpu_models', + 'cudnn_version', + 'pip_version', # 'pip' or 'pip3' + 'pip_packages', + 'conda_packages', + 'hip_compiled_version', + 'hip_runtime_version', + 'miopen_runtime_version', + 'caching_allocator_config', + 'is_xnnpack_available', + 'cpu_info', + 'rocm_version', # vllm specific field + 'neuron_sdk_version', # vllm specific field + 'vllm_version', # vllm specific field + 'vllm_build_flags', # vllm specific field + 'gpu_topo', # vllm specific field + 'env_vars', + ]) + +DEFAULT_CONDA_PATTERNS = { + "torch", + "numpy", + "cudatoolkit", + "soumith", + "mkl", + "magma", + "triton", + "optree", + "nccl", + "transformers", + "zmq", + "nvidia", + "pynvml", +} + +DEFAULT_PIP_PATTERNS = { + "torch", + "numpy", + "mypy", + "flake8", + "triton", + "optree", + "onnx", + "nccl", + "transformers", + "zmq", + "nvidia", + "pynvml", +} + + +def run(command): + """Return (return-code, stdout, stderr).""" + shell = True if type(command) is str else False + p = subprocess.Popen(command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=shell) + raw_output, raw_err = p.communicate() + rc = p.returncode + if get_platform() == 'win32': + enc = 'oem' + else: + enc = locale.getpreferredencoding() + output = raw_output.decode(enc) + if command == 'nvidia-smi topo -m': + # don't remove the leading whitespace of `nvidia-smi topo -m` + # because they are meaningful + output = output.rstrip() + else: + output = output.strip() + err = raw_err.decode(enc) + return rc, output, err.strip() + + +def run_and_read_all(run_lambda, command): + """Run command using run_lambda; reads and returns entire output if rc is 0.""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + return out + + +def run_and_parse_first_match(run_lambda, command, regex): + """Run command using run_lambda, returns the first regex match if it exists.""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + match = re.search(regex, out) + if match is None: + return None + return match.group(1) + + +def run_and_return_first_line(run_lambda, command): + """Run command using run_lambda and returns first line if output is not empty.""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + return out.split('\n')[0] + + +def get_conda_packages(run_lambda, patterns=None): + if patterns is None: + patterns = DEFAULT_CONDA_PATTERNS + conda = os.environ.get('CONDA_EXE', 'conda') + out = run_and_read_all(run_lambda, "{} list".format(conda)) + if out is None: + return out + + return "\n".join(line for line in out.splitlines() + if not line.startswith("#") and any(name in line + for name in patterns)) + + +def get_gcc_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)') + + +def get_clang_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'clang --version', + r'clang version (.*)') + + +def get_cmake_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'cmake --version', + r'cmake (.*)') + + +def get_nvidia_driver_version(run_lambda): + if get_platform() == 'darwin': + cmd = 'kextstat | grep -i cuda' + return run_and_parse_first_match(run_lambda, cmd, + r'com[.]nvidia[.]CUDA [(](.*?)[)]') + smi = get_nvidia_smi() + return run_and_parse_first_match(run_lambda, smi, + r'Driver Version: (.*?) ') + + +def get_gpu_info(run_lambda): + if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr( + torch.version, 'hip') and torch.version.hip is not None): + if TORCH_AVAILABLE and torch.cuda.is_available(): + if torch.version.hip is not None: + prop = torch.cuda.get_device_properties(0) + if hasattr(prop, "gcnArchName"): + gcnArch = " ({})".format(prop.gcnArchName) + else: + gcnArch = "NoGCNArchNameOnOldPyTorch" + else: + gcnArch = "" + return torch.cuda.get_device_name(None) + gcnArch + return None + smi = get_nvidia_smi() + uuid_regex = re.compile(r' \(UUID: .+?\)') + rc, out, _ = run_lambda(smi + ' -L') + if rc != 0: + return None + # Anonymize GPUs by removing their UUID + return re.sub(uuid_regex, '', out) + + +def get_running_cuda_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'nvcc --version', + r'release .+ V(.*)') + + +def get_cudnn_version(run_lambda): + """Return a list of libcudnn.so; it's hard to tell which one is being used.""" + if get_platform() == 'win32': + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') + cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%") + where_cmd = os.path.join(system_root, 'System32', 'where') + cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) + elif get_platform() == 'darwin': + # CUDA libraries and drivers can be found in /usr/local/cuda/. See + # https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install + # https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac + # Use CUDNN_LIBRARY when cudnn library is installed elsewhere. + cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*' + else: + cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' + rc, out, _ = run_lambda(cudnn_cmd) + # find will return 1 if there are permission errors or if not found + if len(out) == 0 or (rc != 1 and rc != 0): + l = os.environ.get('CUDNN_LIBRARY') + if l is not None and os.path.isfile(l): + return os.path.realpath(l) + return None + files_set = set() + for fn in out.split('\n'): + fn = os.path.realpath(fn) # eliminate symbolic links + if os.path.isfile(fn): + files_set.add(fn) + if not files_set: + return None + # Alphabetize the result because the order is non-deterministic otherwise + files = sorted(files_set) + if len(files) == 1: + return files[0] + result = '\n'.join(files) + return 'Probably one of the following:\n{}'.format(result) + + +def get_nvidia_smi(): + # Note: nvidia-smi is currently available only on Windows and Linux + smi = 'nvidia-smi' + if get_platform() == 'win32': + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') + program_files_root = os.environ.get('PROGRAMFILES', + 'C:\\Program Files') + legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', + 'NVSMI', smi) + new_path = os.path.join(system_root, 'System32', smi) + smis = [new_path, legacy_path] + for candidate_smi in smis: + if os.path.exists(candidate_smi): + smi = '"{}"'.format(candidate_smi) + break + return smi + + +def get_rocm_version(run_lambda): + """Returns the ROCm version if available, otherwise 'N/A'.""" + return run_and_parse_first_match(run_lambda, 'hipcc --version', + r'HIP version: (\S+)') + + +def get_neuron_sdk_version(run_lambda): + # Adapted from your install script + try: + result = run_lambda(["neuron-ls"]) + return result if result[0] == 0 else 'N/A' + except Exception: + return 'N/A' + + +def get_vllm_version(): + from vllm import __version__, __version_tuple__ + + if __version__ == "dev": + return "N/A (dev)" + version_str = __version_tuple__[-1] + if isinstance(version_str, str) and version_str.startswith('g'): + # it's a dev build + if '.' in version_str: + # it's a dev build containing local changes + git_sha = version_str.split('.')[0][1:] + date = version_str.split('.')[-1][1:] + return f"{__version__} (git sha: {git_sha}, date: {date})" + else: + # it's a dev build without local changes + git_sha = version_str[1:] # type: ignore + return f"{__version__} (git sha: {git_sha})" + return __version__ + + +def summarize_vllm_build_flags(): + # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. + return 'CUDA Archs: {}; ROCm: {}; Neuron: {}'.format( + os.environ.get('TORCH_CUDA_ARCH_LIST', 'Not Set'), + 'Enabled' if os.environ.get('ROCM_HOME') else 'Disabled', + 'Enabled' if os.environ.get('NEURON_CORES') else 'Disabled', + ) + + +def get_gpu_topo(run_lambda): + output = None + + if get_platform() == 'linux': + output = run_and_read_all(run_lambda, 'nvidia-smi topo -m') + if output is None: + output = run_and_read_all(run_lambda, 'rocm-smi --showtopo') + + return output + + +# example outputs of CPU infos +# * linux +# Architecture: x86_64 +# CPU op-mode(s): 32-bit, 64-bit +# Address sizes: 46 bits physical, 48 bits virtual +# Byte Order: Little Endian +# CPU(s): 128 +# On-line CPU(s) list: 0-127 +# Vendor ID: GenuineIntel +# Model name: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz +# CPU family: 6 +# Model: 106 +# Thread(s) per core: 2 +# Core(s) per socket: 32 +# Socket(s): 2 +# Stepping: 6 +# BogoMIPS: 5799.78 +# Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr +# sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl +# xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 +# pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand +# hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced +# fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap +# avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 +# xsaves wbnoinvd ida arat avx512vbmi pku ospke avx512_vbmi2 gfni vaes vpclmulqdq +# avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear flush_l1d arch_capabilities +# Virtualization features: +# Hypervisor vendor: KVM +# Virtualization type: full +# Caches (sum of all): +# L1d: 3 MiB (64 instances) +# L1i: 2 MiB (64 instances) +# L2: 80 MiB (64 instances) +# L3: 108 MiB (2 instances) +# NUMA: +# NUMA node(s): 2 +# NUMA node0 CPU(s): 0-31,64-95 +# NUMA node1 CPU(s): 32-63,96-127 +# Vulnerabilities: +# Itlb multihit: Not affected +# L1tf: Not affected +# Mds: Not affected +# Meltdown: Not affected +# Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown +# Retbleed: Not affected +# Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp +# Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization +# Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence +# Srbds: Not affected +# Tsx async abort: Not affected +# * win32 +# Architecture=9 +# CurrentClockSpeed=2900 +# DeviceID=CPU0 +# Family=179 +# L2CacheSize=40960 +# L2CacheSpeed= +# Manufacturer=GenuineIntel +# MaxClockSpeed=2900 +# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz +# ProcessorType=3 +# Revision=27142 +# +# Architecture=9 +# CurrentClockSpeed=2900 +# DeviceID=CPU1 +# Family=179 +# L2CacheSize=40960 +# L2CacheSpeed= +# Manufacturer=GenuineIntel +# MaxClockSpeed=2900 +# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz +# ProcessorType=3 +# Revision=27142 + + +def get_cpu_info(run_lambda): + rc, out, err = 0, '', '' + if get_platform() == 'linux': + rc, out, err = run_lambda('lscpu') + elif get_platform() == 'win32': + rc, out, err = run_lambda( + 'wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \ + CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE' + ) + elif get_platform() == 'darwin': + rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string") + cpu_info = 'None' + if rc == 0: + cpu_info = out + else: + cpu_info = err + return cpu_info + + +def get_platform(): + if sys.platform.startswith('linux'): + return 'linux' + elif sys.platform.startswith('win32'): + return 'win32' + elif sys.platform.startswith('cygwin'): + return 'cygwin' + elif sys.platform.startswith('darwin'): + return 'darwin' + else: + return sys.platform + + +def get_mac_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', + r'(.*)') + + +def get_windows_version(run_lambda): + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') + wmic_cmd = os.path.join(system_root, 'System32', 'Wbem', 'wmic') + findstr_cmd = os.path.join(system_root, 'System32', 'findstr') + return run_and_read_all( + run_lambda, + '{} os get Caption | {} /v Caption'.format(wmic_cmd, findstr_cmd)) + + +def get_lsb_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'lsb_release -a', + r'Description:\t(.*)') + + +def check_release_file(run_lambda): + return run_and_parse_first_match(run_lambda, 'cat /etc/*-release', + r'PRETTY_NAME="(.*)"') + + +def get_os(run_lambda): + from platform import machine + platform = get_platform() + + if platform == 'win32' or platform == 'cygwin': + return get_windows_version(run_lambda) + + if platform == 'darwin': + version = get_mac_version(run_lambda) + if version is None: + return None + return 'macOS {} ({})'.format(version, machine()) + + if platform == 'linux': + # Ubuntu/Debian based + desc = get_lsb_version(run_lambda) + if desc is not None: + return '{} ({})'.format(desc, machine()) + + # Try reading /etc/*-release + desc = check_release_file(run_lambda) + if desc is not None: + return '{} ({})'.format(desc, machine()) + + return '{} ({})'.format(platform, machine()) + + # Unknown platform + return platform + + +def get_python_platform(): + import platform + return platform.platform() + + +def get_libc_version(): + import platform + if get_platform() != 'linux': + return 'N/A' + return '-'.join(platform.libc_ver()) + + +def get_pip_packages(run_lambda, patterns=None): + """Return `pip list` output. Note: will also find conda-installed pytorch and numpy packages.""" + if patterns is None: + patterns = DEFAULT_PIP_PATTERNS + + def run_with_pip(): + try: + import importlib.util + pip_spec = importlib.util.find_spec('pip') + pip_available = pip_spec is not None + except ImportError: + pip_available = False + + if pip_available: + cmd = [sys.executable, '-mpip', 'list', '--format=freeze'] + elif os.environ.get("UV") is not None: + print("uv is set") + cmd = ["uv", "pip", "list", "--format=freeze"] + else: + raise RuntimeError( + "Could not collect pip list output (pip or uv module not available)" + ) + + out = run_and_read_all(run_lambda, cmd) + return "\n".join(line for line in out.splitlines() + if any(name in line for name in patterns)) + + pip_version = 'pip3' if sys.version[0] == '3' else 'pip' + out = run_with_pip() + return pip_version, out + + +def get_cachingallocator_config(): + ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '') + return ca_config + + +def get_cuda_module_loading_config(): + if TORCH_AVAILABLE and torch.cuda.is_available(): + torch.cuda.init() + config = os.environ.get('CUDA_MODULE_LOADING', '') + return config + else: + return "N/A" + + +def is_xnnpack_available(): + if TORCH_AVAILABLE: + import torch.backends.xnnpack + return str( + torch.backends.xnnpack.enabled) # type: ignore[attr-defined] + else: + return "N/A" + + +def get_env_vars(): + env_vars = '' + secret_terms = ('secret', 'token', 'api', 'access', 'password') + report_prefix = ("TORCH", "NCCL", "PYTORCH", "CUDA", "CUBLAS", "CUDNN", + "OMP_", "MKL_", "NVIDIA") + for k, v in os.environ.items(): + if any(term in k.lower() for term in secret_terms): + continue + if k in environment_variables: + env_vars = env_vars + "{}={}".format(k, v) + "\n" + if k.startswith(report_prefix): + env_vars = env_vars + "{}={}".format(k, v) + "\n" + + return env_vars + + +def get_env_info(): + run_lambda = run + pip_version, pip_list_output = get_pip_packages(run_lambda) + + if TORCH_AVAILABLE: + version_str = torch.__version__ + debug_mode_str = str(torch.version.debug) + cuda_available_str = str(torch.cuda.is_available()) + cuda_version_str = torch.version.cuda + if not hasattr(torch.version, + 'hip') or torch.version.hip is None: # cuda version + hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' + else: # HIP version + + def get_version_or_na(cfg, prefix): + _lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s] + return _lst[0] if _lst else 'N/A' + + cfg = torch._C._show_config().split('\n') + hip_runtime_version = get_version_or_na(cfg, 'HIP Runtime') + miopen_runtime_version = get_version_or_na(cfg, 'MIOpen') + cuda_version_str = 'N/A' + hip_compiled_version = torch.version.hip + else: + version_str = debug_mode_str = cuda_available_str = cuda_version_str = 'N/A' + hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' + + sys_version = sys.version.replace("\n", " ") + + conda_packages = get_conda_packages(run_lambda) + + rocm_version = get_rocm_version(run_lambda) + neuron_sdk_version = get_neuron_sdk_version(run_lambda) + vllm_version = get_vllm_version() + vllm_build_flags = summarize_vllm_build_flags() + gpu_topo = get_gpu_topo(run_lambda) + + return SystemEnv( + torch_version=version_str, + is_debug_build=debug_mode_str, + python_version='{} ({}-bit runtime)'.format( + sys_version, + sys.maxsize.bit_length() + 1), + python_platform=get_python_platform(), + is_cuda_available=cuda_available_str, + cuda_compiled_version=cuda_version_str, + cuda_runtime_version=get_running_cuda_version(run_lambda), + cuda_module_loading=get_cuda_module_loading_config(), + nvidia_gpu_models=get_gpu_info(run_lambda), + nvidia_driver_version=get_nvidia_driver_version(run_lambda), + cudnn_version=get_cudnn_version(run_lambda), + hip_compiled_version=hip_compiled_version, + hip_runtime_version=hip_runtime_version, + miopen_runtime_version=miopen_runtime_version, + pip_version=pip_version, + pip_packages=pip_list_output, + conda_packages=conda_packages, + os=get_os(run_lambda), + libc_version=get_libc_version(), + gcc_version=get_gcc_version(run_lambda), + clang_version=get_clang_version(run_lambda), + cmake_version=get_cmake_version(run_lambda), + caching_allocator_config=get_cachingallocator_config(), + is_xnnpack_available=is_xnnpack_available(), + cpu_info=get_cpu_info(run_lambda), + rocm_version=rocm_version, + neuron_sdk_version=neuron_sdk_version, + vllm_version=vllm_version, + vllm_build_flags=vllm_build_flags, + gpu_topo=gpu_topo, + env_vars=get_env_vars(), + ) + + +env_info_fmt = """ +============================== + System Info +============================== +OS : {os} +GCC version : {gcc_version} +Clang version : {clang_version} +CMake version : {cmake_version} +Libc version : {libc_version} + +============================== + PyTorch Info +============================== +PyTorch version : {torch_version} +Is debug build : {is_debug_build} +CUDA used to build PyTorch : {cuda_compiled_version} +ROCM used to build PyTorch : {hip_compiled_version} + +============================== + Python Environment +============================== +Python version : {python_version} +Python platform : {python_platform} + +============================== + CUDA / GPU Info +============================== +Is CUDA available : {is_cuda_available} +CUDA runtime version : {cuda_runtime_version} +CUDA_MODULE_LOADING set to : {cuda_module_loading} +GPU models and configuration : {nvidia_gpu_models} +Nvidia driver version : {nvidia_driver_version} +cuDNN version : {cudnn_version} +HIP runtime version : {hip_runtime_version} +MIOpen runtime version : {miopen_runtime_version} +Is XNNPACK available : {is_xnnpack_available} + +============================== + CPU Info +============================== +{cpu_info} + +============================== +Versions of relevant libraries +============================== +{pip_packages} +{conda_packages} +""".strip() + +# both the above code and the following code use `strip()` to +# remove leading/trailing whitespaces, so we need to add a newline +# in between to separate the two sections +env_info_fmt += "\n\n" + +env_info_fmt += """ +============================== + vLLM Info +============================== +ROCM Version : {rocm_version} +Neuron SDK Version : {neuron_sdk_version} +vLLM Version : {vllm_version} +vLLM Build Flags: + {vllm_build_flags} +GPU Topology: + {gpu_topo} + +============================== + Environment Variables +============================== +{env_vars} +""".strip() + + +def pretty_str(envinfo): + + def replace_nones(dct, replacement='Could not collect'): + for key in dct.keys(): + if dct[key] is not None: + continue + dct[key] = replacement + return dct + + def replace_bools(dct, true='Yes', false='No'): + for key in dct.keys(): + if dct[key] is True: + dct[key] = true + elif dct[key] is False: + dct[key] = false + return dct + + def prepend(text, tag='[prepend]'): + lines = text.split('\n') + updated_lines = [tag + line for line in lines] + return '\n'.join(updated_lines) + + def replace_if_empty(text, replacement='No relevant packages'): + if text is not None and len(text) == 0: + return replacement + return text + + def maybe_start_on_next_line(string): + # If `string` is multiline, prepend a \n to it. + if string is not None and len(string.split('\n')) > 1: + return '\n{}\n'.format(string) + return string + + mutable_dict = envinfo._asdict() + + # If nvidia_gpu_models is multiline, start on the next line + mutable_dict['nvidia_gpu_models'] = \ + maybe_start_on_next_line(envinfo.nvidia_gpu_models) + + # If the machine doesn't have CUDA, report some fields as 'No CUDA' + dynamic_cuda_fields = [ + 'cuda_runtime_version', + 'nvidia_gpu_models', + 'nvidia_driver_version', + ] + all_cuda_fields = dynamic_cuda_fields + ['cudnn_version'] + all_dynamic_cuda_fields_missing = all(mutable_dict[field] is None + for field in dynamic_cuda_fields) + if TORCH_AVAILABLE and not torch.cuda.is_available( + ) and all_dynamic_cuda_fields_missing: + for field in all_cuda_fields: + mutable_dict[field] = 'No CUDA' + if envinfo.cuda_compiled_version is None: + mutable_dict['cuda_compiled_version'] = 'None' + + # Replace True with Yes, False with No + mutable_dict = replace_bools(mutable_dict) + + # Replace all None objects with 'Could not collect' + mutable_dict = replace_nones(mutable_dict) + + # If either of these are '', replace with 'No relevant packages' + mutable_dict['pip_packages'] = replace_if_empty( + mutable_dict['pip_packages']) + mutable_dict['conda_packages'] = replace_if_empty( + mutable_dict['conda_packages']) + + # Tag conda and pip packages with a prefix + # If they were previously None, they'll show up as ie '[conda] Could not collect' + if mutable_dict['pip_packages']: + mutable_dict['pip_packages'] = prepend( + mutable_dict['pip_packages'], '[{}] '.format(envinfo.pip_version)) + if mutable_dict['conda_packages']: + mutable_dict['conda_packages'] = prepend( + mutable_dict['conda_packages'], '[conda] ') + mutable_dict['cpu_info'] = envinfo.cpu_info + return env_info_fmt.format(**mutable_dict) + + +def get_pretty_env_info(): + return pretty_str(get_env_info()) + + +def main(): + print("Collecting environment information...") + output = get_pretty_env_info() + print(output) + + if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr( + torch.utils, '_crash_handler'): + minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR + if sys.platform == "linux" and os.path.exists(minidump_dir): + dumps = [ + os.path.join(minidump_dir, dump) + for dump in os.listdir(minidump_dir) + ] + latest = max(dumps, key=os.path.getctime) + ctime = os.path.getctime(latest) + creation_time = datetime.datetime.fromtimestamp(ctime).strftime( + '%Y-%m-%d %H:%M:%S') + msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \ + "if this is related to your bug please include it when you file a report ***" + print(msg, file=sys.stderr) + + +if __name__ == '__main__': + main() diff --git a/vllm/compilation/__init__.py b/vllm/compilation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py new file mode 100644 index 0000000..ce4e50a --- /dev/null +++ b/vllm/compilation/activation_quant_fusion.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only, + register_replacement) + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +def silu_mul_pattern_static(result: torch.Tensor, + result_silu_mul: torch.Tensor, input: torch.Tensor, + scale: torch.Tensor): + at1 = auto_functionalized(torch.ops._C.silu_and_mul.default, + result=result_silu_mul, + input=input) + at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, + result=result, + input=at1[1], + scale=scale) + return at2[1] + + +def silu_mul_replacement_static(result: torch.Tensor, + result_silu_mul: torch.Tensor, + input: torch.Tensor, scale: torch.Tensor): + at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default, + result=result, + input=input, + scale=scale) + return at[1] + + +def empty_bf16(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") + + +def empty_fp8(*args, **kwargs): + fp8 = current_platform.fp8_dtype() + return torch.empty(*args, **kwargs, dtype=fp8, device="cuda") + + +def empty_fp32(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") + + +class ActivationQuantFusionPass(VllmInductorPass): + """ + This pass fuses a pre-defined set of custom ops into fused ops. + It uses the torch pattern matcher to find the patterns and replace them. + + Because patterns can only be registered once, the pass is a singleton. + This will be addressed in a future version of PyTorch: + https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 + """ + + def __init__(self, config: VllmConfig): + super().__init__(config) + + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="activation_quant_fusion_pass") + + inputs = [ + empty_fp8(5, 4), # Quant output + empty_bf16(5, 4), # Silu_and_mul output + empty_bf16(5, 4), # Input + empty_fp32(1, 1) # Scale + ] + register_replacement(silu_mul_pattern_static, + silu_mul_replacement_static, inputs, fwd_only, + self.patterns) + + def __call__(self, graph: torch.fx.Graph): + self.begin() + self.dump_graph(graph, "before_act_quant_fusion") + + count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns in ActivationQuantFusionPass", + count) + + self.dump_graph(graph, "after_act_quant_fusion") + self.end_and_log() diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py new file mode 100644 index 0000000..a2bb053 --- /dev/null +++ b/vllm/compilation/backends.py @@ -0,0 +1,610 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import ast +import dataclasses +import os +import pprint +import time +from collections.abc import Sequence +from contextlib import contextmanager +from typing import Any, Callable, Optional + +import torch +import torch.fx as fx +from torch._dispatch.python import enable_python_dispatcher + +import vllm.envs as envs +from vllm.config import CompilationConfig, VllmConfig +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname + +from .compiler_interface import (CompilerInterface, EagerAdaptor, + InductorAdaptor, InductorStandaloneAdaptor) +from .counter import compilation_counter +from .inductor_pass import InductorPass +from .pass_manager import PostGradPassManager + +logger = init_logger(__name__) + + +def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface: + if compilation_config.use_inductor: + if envs.VLLM_USE_STANDALONE_COMPILE and is_torch_equal_or_newer( + "2.8.0.dev"): + logger.debug("Using InductorStandaloneAdaptor") + return InductorStandaloneAdaptor() + else: + logger.debug("Using InductorAdaptor") + return InductorAdaptor() + else: + logger.debug("Using EagerAdaptor") + return EagerAdaptor() + + +class CompilerManager: + """ + A manager to manage the compilation process, including + caching the compiled graph, loading the compiled graph, + and compiling the graph. + + The cache is a dict mapping + `(runtime_shape, graph_index, backend_name)` + to `any_data` returned from the compiler. + + When serializing the cache, we save it to a Python file + for readability. We don't use json here because json doesn't + support int as key. + """ + + def __init__(self, compilation_config: CompilationConfig): + self.cache: dict[tuple[Optional[int], int, str], Any] = dict() + self.is_cache_updated = False + self.compilation_config = compilation_config + self.compiler = make_compiler(compilation_config) + + def compute_hash(self, vllm_config: VllmConfig) -> str: + return self.compiler.compute_hash(vllm_config) + + def initialize_cache(self, + cache_dir: str, + disable_cache: bool = False, + prefix: str = ""): + """ + Initialize the cache directory for the compiler. + + The organization of the cache directory is as follows: + cache_dir=/path/to/hash_str/rank_i_j/prefix/ + inside cache_dir, there will be: + - vllm_compile_cache.py + - computation_graph.py + - transformed_code.py + + for multiple prefixes, they can share the same + base cache dir of /path/to/hash_str/rank_i_j/ , + to store some common compilation artifacts. + """ + + self.disable_cache = disable_cache + self.cache_dir = cache_dir + self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py") + + if not disable_cache and os.path.exists(self.cache_file_path): + # load the cache from the file + with open(self.cache_file_path) as f: + # we use ast.literal_eval to parse the data + # because it is a safe way to parse Python literals. + # do not use eval(), it is unsafe. + self.cache = ast.literal_eval(f.read()) + + self.compiler.initialize_cache(cache_dir=cache_dir, + disable_cache=disable_cache, + prefix=prefix) + + def save_to_file(self): + if self.disable_cache or not self.is_cache_updated: + return + printer = pprint.PrettyPrinter(indent=4) + data = printer.pformat(self.cache) + with open(self.cache_file_path, "w") as f: + f.write(data) + + def load(self, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Optional[Callable]: + if (runtime_shape, graph_index, self.compiler.name) not in self.cache: + return None + handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] + compiled_graph = self.compiler.load(handle, graph, example_inputs, + graph_index, runtime_shape) + logger.debug( + "Directly load the %s-th graph for shape %s from %s via " + "handle %s", graph_index, str(runtime_shape), self.compiler.name, + handle) + return compiled_graph + + def compile(self, + graph: fx.GraphModule, + example_inputs, + additional_inductor_config, + compilation_config: CompilationConfig, + graph_index: int = 0, + num_graphs: int = 1, + runtime_shape: Optional[int] = None) -> Any: + if graph_index == 0: + # before compiling the first graph, record the start time + global compilation_start_time + compilation_start_time = time.time() + + compilation_counter.num_backend_compilations += 1 + + compiled_graph = None + + # try to load from the cache + compiled_graph = self.load(graph, example_inputs, graph_index, + runtime_shape) + if compiled_graph is not None: + if graph_index == num_graphs - 1: + # after loading the last graph for this shape, record the time. + # there can be multiple graphs due to piecewise compilation. + now = time.time() + elapsed = now - compilation_start_time + logger.info( + "Directly load the compiled graph(s) for shape %s " + "from the cache, took %.3f s", str(runtime_shape), elapsed) + return compiled_graph + + # no compiler cached the graph, or the cache is disabled, + # we need to compile it + if isinstance(self.compiler, InductorAdaptor): + # Let compile_fx generate a key for us + maybe_key = None + else: + maybe_key = \ + f"artifact_shape_{runtime_shape}_subgraph_{graph_index}" + compiled_graph, handle = self.compiler.compile( + graph, example_inputs, additional_inductor_config, runtime_shape, + maybe_key) + + assert compiled_graph is not None, "Failed to compile the graph" + + # store the artifact in the cache + if handle is not None: + self.cache[(runtime_shape, graph_index, + self.compiler.name)] = handle + self.is_cache_updated = True + if graph_index == 0: + # adds some info logging for the first graph + logger.info("Cache the graph of shape %s for later use", + str(runtime_shape)) + logger.debug( + "store the %s-th graph for shape %s from %s via handle %s", + graph_index, str(runtime_shape), self.compiler.name, handle) + + # after compiling the last graph, record the end time + if graph_index == num_graphs - 1: + now = time.time() + elapsed = now - compilation_start_time + compilation_config.compilation_time += elapsed + if runtime_shape is None: + logger.info("Compiling a graph for general shape takes %.2f s", + elapsed) + else: + logger.info("Compiling a graph for shape %s takes %.2f s", + runtime_shape, elapsed) + + return compiled_graph + + +@dataclasses.dataclass +class SplitItem: + submod_name: str + graph_id: int + is_splitting_graph: bool + graph: fx.GraphModule + + +def split_graph(graph: fx.GraphModule, + ops: list[str]) -> tuple[fx.GraphModule, list[SplitItem]]: + # split graph by ops + subgraph_id = 0 + node_to_subgraph_id = {} + split_op_graphs = [] + for node in graph.graph.nodes: + if node.op in ("output", "placeholder"): + continue + if node.op == 'call_function' and str(node.target) in ops: + subgraph_id += 1 + node_to_subgraph_id[node] = subgraph_id + split_op_graphs.append(subgraph_id) + subgraph_id += 1 + else: + node_to_subgraph_id[node] = subgraph_id + + # `keep_original_order` is important! + # otherwise pytorch might reorder the nodes and + # the semantics of the graph will change when we + # have mutations in the graph + split_gm = torch.fx.passes.split_module.split_module( + graph, + None, + lambda node: node_to_subgraph_id[node], + keep_original_order=True) + + outputs = [] + + names = [name for (name, module) in split_gm.named_modules()] + + for name in names: + if "." in name or name == "": + # recursive child module or the root module + continue + + module = getattr(split_gm, name) + + graph_id = int(name.replace("submod_", "")) + outputs.append( + SplitItem(name, graph_id, (graph_id in split_op_graphs), module)) + + # sort by intetger graph_id, rather than string name + outputs.sort(key=lambda x: x.graph_id) + + return split_gm, outputs + + +# we share the global graph pool among all the backends +global_graph_pool = None + +compilation_start_time = 0.0 + + +class PiecewiseCompileInterpreter(torch.fx.Interpreter): + """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`. + It runs the given graph with fake inputs, and compile some + submodules specified by `compile_submod_names` with the given + compilation configs. + + NOTE: the order in `compile_submod_names` matters, because + it will be used to determine the order of the compiled piecewise + graphs. The first graph will handle logging, and the last graph + has some special cudagraph output handling. + """ + + def __init__(self, module: torch.fx.GraphModule, + compile_submod_names: list[str], vllm_config: VllmConfig, + graph_pool, vllm_backend: "VllmBackend"): + super().__init__(module) + from torch._guards import detect_fake_mode + self.fake_mode = detect_fake_mode() + self.compile_submod_names = compile_submod_names + self.compilation_config = vllm_config.compilation_config + self.graph_pool = graph_pool + self.vllm_config = vllm_config + self.vllm_backend = vllm_backend + # When True, it annoyingly dumps the torch.fx.Graph on errors. + self.extra_traceback = False + + def run(self, *args): + fake_args = [ + self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t + for t in args + ] + with self.fake_mode, enable_python_dispatcher(): + return super().run(*fake_args) + + def call_module(self, target: torch.fx.node.Target, + args: tuple[torch.fx.node.Argument, + ...], kwargs: dict[str, Any]) -> Any: + assert isinstance(target, str) + output = super().call_module(target, args, kwargs) + + if target in self.compile_submod_names: + index = self.compile_submod_names.index(target) + submod = self.fetch_attr(target) + sym_shape_indices = [ + i for i, x in enumerate(args) if isinstance(x, torch.SymInt) + ] + global compilation_start_time + compiled_graph_for_general_shape = self.vllm_backend.\ + compiler_manager.compile( + submod, + args, + self.compilation_config.inductor_compile_config, + self.compilation_config, + graph_index=index, + num_graphs=len(self.compile_submod_names), + runtime_shape=None) + + piecewise_backend = resolve_obj_by_qualname( + current_platform.get_piecewise_backend_cls()) + self.module.__dict__[target] = piecewise_backend( + submod, self.vllm_config, self.graph_pool, index, + len(self.compile_submod_names), sym_shape_indices, + compiled_graph_for_general_shape, self.vllm_backend) + + compilation_counter.num_piecewise_capturable_graphs_seen += 1 + + return output + + +# the tag for the part of model being compiled, +# e.g. backbone/eagle_head +model_tag: str = "backbone" + + +@contextmanager +def set_model_tag(tag: str): + """Context manager to set the model tag.""" + global model_tag + assert tag != model_tag, \ + f"Model tag {tag} is the same as the current tag {model_tag}." + old_tag = model_tag + model_tag = tag + try: + yield + finally: + model_tag = old_tag + + +class VllmBackend: + """The compilation backend for `torch.compile` with vLLM. + It is used for compilation level of `CompilationLevel.PIECEWISE`, + where we customize the compilation. + + The major work of this backend is to split the graph into + piecewise graphs, and pass them to the piecewise backend. + + This backend also adds the PostGradPassManager to Inductor config, + which handles the post-grad passes. + """ + + vllm_config: VllmConfig + compilation_config: CompilationConfig + graph_pool: Any + _called: bool = False + # the graph we compiled + graph: fx.GraphModule + # the stiching graph module for all the piecewise graphs + split_gm: fx.GraphModule + piecewise_graphs: list[SplitItem] + returned_callable: Callable + # Inductor passes to run on the graph pre-defunctionalization + post_grad_passes: Sequence[Callable] + sym_tensor_indices: list[int] + input_buffers: list[torch.Tensor] + compiler_manager: CompilerManager + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ): + + # if the model is initialized with a non-empty prefix, + # then usually it's enough to use that prefix, + # e.g. launguage_model, vision_model, etc. + # when multiple parts are initialized as independent + # models, we need to use the model_tag to distinguish + # them, e.g. backbone (default), eagle_head, etc. + self.prefix = prefix or model_tag + + global global_graph_pool + if global_graph_pool is None: + global_graph_pool = current_platform.graph_pool_handle() + + # TODO: in the future, if we want to use multiple + # streams, it might not be safe to share a global pool. + # only investigate this when we use multiple streams + self.graph_pool = global_graph_pool + + # Passes to run on the graph post-grad. + self.post_grad_pass_manager = PostGradPassManager() + + self.sym_tensor_indices = [] + self.input_buffers = [] + + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + + self.compiler_manager: CompilerManager = CompilerManager( + self.compilation_config) + + # `torch.compile` is JIT compiled, so we don't need to + # do anything here + + def configure_post_pass(self): + config = self.compilation_config + self.post_grad_pass_manager.configure(self.vllm_config) + + # Post-grad custom passes are run using the post_grad_custom_post_pass + # hook. If a pass for that hook exists, add it to the pass manager. + inductor_config = config.inductor_compile_config + PASS_KEY = "post_grad_custom_post_pass" + if PASS_KEY in inductor_config: + # Config should automatically wrap all inductor passes + if isinstance(inductor_config[PASS_KEY], PostGradPassManager): + assert (inductor_config[PASS_KEY].uuid() == + self.post_grad_pass_manager.uuid()) + else: + assert isinstance(inductor_config[PASS_KEY], InductorPass) + self.post_grad_pass_manager.add(inductor_config[PASS_KEY]) + inductor_config[PASS_KEY] = self.post_grad_pass_manager + + def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: + + vllm_config = self.vllm_config + if not self.compilation_config.cache_dir: + # no provided cache dir, generate one based on the known factors + # that affects the compilation. if none of the factors change, + # the cache dir will be the same so that we can reuse the compiled + # graph. + + factors = [] + # 0. factors come from the env, for example, The values of + # VLLM_PP_LAYER_PARTITION will affects the computation graph. + env_hash = envs.compute_hash() + factors.append(env_hash) + + # 1. factors come from the vllm_config (it mainly summarizes how the + # model is created) + config_hash = vllm_config.compute_hash() + factors.append(config_hash) + + # 2. factors come from the code files that are traced by Dynamo ( + # it mainly summarizes how the model is used in forward pass) + forward_code_files = list( + sorted(self.compilation_config.traced_files)) + self.compilation_config.traced_files.clear() + logger.debug( + "Traced files (to be considered for compilation cache):\n%s", + "\n".join(forward_code_files)) + hash_content = [] + for filepath in forward_code_files: + hash_content.append(filepath) + if filepath == "": + # This means the function was dynamically generated, with + # e.g. exec(). We can't actually check these. + continue + with open(filepath) as f: + hash_content.append(f.read()) + import hashlib + code_hash = hashlib.md5("\n".join(hash_content).encode(), + usedforsecurity=False).hexdigest() + factors.append(code_hash) + + # 3. compiler hash + compiler_hash = self.compiler_manager.compute_hash(vllm_config) + factors.append(compiler_hash) + + # combine all factors to generate the cache dir + hash_key = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest()[:10] + + cache_dir = os.path.join( + envs.VLLM_CACHE_ROOT, + "torch_compile_cache", + hash_key, + ) + self.compilation_config.cache_dir = cache_dir + + cache_dir = self.compilation_config.cache_dir + os.makedirs(cache_dir, exist_ok=True) + self.compilation_config.cache_dir = cache_dir + rank = vllm_config.parallel_config.rank + dp_rank = vllm_config.parallel_config.data_parallel_rank + local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", + self.prefix) + os.makedirs(local_cache_dir, exist_ok=True) + self.compilation_config.local_cache_dir = local_cache_dir + + disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE + + if disable_cache: + logger.info("vLLM's torch.compile cache is disabled.") + else: + logger.info("Using cache directory: %s for vLLM's torch.compile", + local_cache_dir) + + self.compiler_manager.initialize_cache(local_cache_dir, disable_cache, + self.prefix) + + # when dynamo calls the backend, it means the bytecode + # transform and analysis are done + compilation_counter.num_graphs_seen += 1 + from .monitor import torch_compile_start_time + dynamo_time = time.time() - torch_compile_start_time + logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time) + self.compilation_config.compilation_time += dynamo_time + + # we control the compilation process, each instance can only be + # called once + assert not self._called, "VllmBackend can only be called once" + + self.graph = graph + self.configure_post_pass() + + self.split_gm, self.piecewise_graphs = split_graph( + graph, self.compilation_config.splitting_ops) + + from torch._dynamo.utils import lazy_format_graph_code + + # depyf will hook lazy_format_graph_code and dump the graph + # for debugging, no need to print the graph here + lazy_format_graph_code("before split", self.graph) + lazy_format_graph_code("after split", self.split_gm) + + compilation_counter.num_piecewise_graphs_seen += len( + self.piecewise_graphs) + submod_names_to_compile = [ + item.submod_name for item in self.piecewise_graphs + if not item.is_splitting_graph + ] + + # propagate the split graph to the piecewise backend, + # compile submodules with symbolic shapes + PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile, + self.vllm_config, self.graph_pool, + self).run(*example_inputs) + + graph_path = os.path.join(local_cache_dir, "computation_graph.py") + if not os.path.exists(graph_path): + # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa + # use `print_readable` because it can include submodules + src = "from __future__ import annotations\nimport torch\n" + \ + self.split_gm.print_readable(print_output=False) + src = src.replace("", "GraphModule") + with open(graph_path, "w") as f: + f.write(src) + + logger.debug("Computation graph saved to %s", graph_path) + + self._called = True + + if not self.compilation_config.use_cudagraph or \ + not self.compilation_config.cudagraph_copy_inputs: + return self.split_gm + + # if we need to copy input buffers for cudagraph + from torch._guards import detect_fake_mode + fake_mode = detect_fake_mode() + fake_args = [ + fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t + for t in example_inputs + ] + + # index of tensors that have symbolic shapes (batch size) + # for weights and static buffers, they will have concrete shapes. + # symbolic shape only happens for input tensors. + from torch.fx.experimental.symbolic_shapes import is_symbolic + self.sym_tensor_indices = [ + i for i, x in enumerate(fake_args) + if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \ + any(is_symbolic(d) for d in x.size()) + ] + + # compiler managed cudagraph input buffers + # we assume the first run with symbolic shapes + # has the maximum size among all the tensors + self.input_buffers = [ + example_inputs[x].clone() for x in self.sym_tensor_indices + ] + + # this is the callable we return to Dynamo to run + def copy_and_call(*args): + list_args = list(args) + for i, index in enumerate(self.sym_tensor_indices): + runtime_tensor = list_args[index] + runtime_shape = runtime_tensor.shape[0] + static_tensor = self.input_buffers[i][:runtime_shape] + + # copy the tensor to the static buffer + static_tensor.copy_(runtime_tensor) + + # replace the tensor in the list_args to the static buffer + list_args[index] = static_tensor + return self.split_gm(*list_args) + + return copy_and_call diff --git a/vllm/compilation/base_piecewise_backend.py b/vllm/compilation/base_piecewise_backend.py new file mode 100644 index 0000000..4d7aeeb --- /dev/null +++ b/vllm/compilation/base_piecewise_backend.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Callable, Protocol + +import torch.fx as fx + +from vllm.compilation.backends import VllmBackend +from vllm.config import VllmConfig + + +class AbstractPiecewiseBackend(Protocol): + """ + PiecewiseBackend interface that allows platforms to extend + piecewise static graph. + """ + + def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, + graph_pool: Any, piecewise_compile_index: int, + total_piecewise_compiles: int, sym_shape_indices: list[int], + compiled_graph_for_general_shape: Callable, + vllm_backend: VllmBackend, **kwargs): + """ + Initializes the PiecewiseBackend class with compilation and + execution-related configurations. + + This class handles piecewise compilation, graph capturing, + and dispatching for specific input shapes. + + Args: + graph (fx.GraphModule): The graph represented in fx. + vllm_config (VllmConfig): Global configuration for vLLM. + graph_pool (Any): + Graph memory pool handle, e.g., + `torch.cuda.graph_pool_handle()`. + piecewise_compile_index (int): + Index of the current piecewise subgraph. + total_piecewise_compiles (int): + Total number of piecewise-compiled graphs. + sym_shape_indices (list[int]): + Indices of symbolic shape. + compiled_graph_for_general_shape (Callable): + Callable that executes the graph compiled for general shapes. + vllm_backend (VllmBackend): + Backend compiler that manages compilation and graph runtime + for vLLM. + + Keyword Args: + kwargs: Additional keyword arguments reserved for future + extensions or custom platforms. + """ + raise NotImplementedError + + def __call__(self, *args) -> Any: + """Executes the compiled graph for given input args. + + If this is the first invocation, executes the general compiled graph + and initiates the compilation process tracking. For subsequent calls, + dynamically dispatches execution to either a compiled graph or a static + graph based on the input shape. + + Args: + *args: Variable length input arguments to be passed into the + graph. The symbolic shape is expected to be in position + `sym_shape_indices[0]`. + + Returns: + Any: Output of the executed graph. This can be from the general + compiled graph, a specialized compiled version for the given shape, + or a replayed static graph. + """ + raise NotImplementedError diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py new file mode 100644 index 0000000..f754fc2 --- /dev/null +++ b/vllm/compilation/collective_fusion.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch +import torch._inductor.pattern_matcher as pm +import torch.fx as fx +from torch._inductor.pattern_matcher import PatternMatcherPass +from torch.distributed._symmetric_memory import enable_symm_mem_for_group + +from vllm.config import VllmConfig +from vllm.distributed import get_tp_group +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.logger import init_logger + +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +class BasePattern: + + def __init__(self, dtype: torch.dtype, device: str): + self.dtype = dtype + self.device = device + self.tp = get_tp_group() + self.tp_size = get_tensor_model_parallel_world_size() + + +class GEMMReduceScatterPattern(BasePattern): + + def get_inputs(self): + mul = torch.empty([16, 4], device=self.device, dtype=self.dtype) + mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) + return [mul, mm_weight] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern(mul: torch.Tensor, mm_weight: torch.Tensor): + mm = torch.ops.aten.mm.default(mul, mm_weight) + reduce_scatter = torch.ops.vllm.reduce_scatter.default( + mm, + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name) + return reduce_scatter + + def replacement(mul: torch.Tensor, mm_weight: torch.Tensor): + gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter( + mul, + mm_weight, + "avg", + scatter_dim=0, + group_name=self.tp.device_group.group_name, + ) + + return gemm_rs + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AllGatherGEMMPattern(BasePattern): + + def get_inputs(self): + x = torch.empty([4, 4], device=self.device, dtype=self.dtype) + weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + return [x, weight] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + x: torch.Tensor, + weight: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + all_gather = torch.ops.vllm.all_gather.default( + x, + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name) + + return torch.ops.aten.mm.default(all_gather, weight) + + def replacement( + x: torch.Tensor, + weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul( + x, + [weight], + gather_dim=0, + group_name=self.tp.device_group.group_name, + ) + return mm_outputs + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AsyncTPPass(VllmInductorPass): + + def __init__(self, config: VllmConfig): + super().__init__(config) + + # Enable symmetric memory for the TP process group + enable_symm_mem_for_group(get_tp_group().device_group.group_name) + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="async_tp_pass") + GEMMReduceScatterPattern(self.model_dtype, + self.device).register(self.patterns) + + AllGatherGEMMPattern(self.model_dtype, + self.device).register(self.patterns) + + def is_applicable_for_shape(self, shape: Optional[int]) -> bool: + # only do replace for specific shapes + tp_size = get_tensor_model_parallel_world_size() + return shape is not None and shape % tp_size == 0 + + def __call__(self, graph: fx.Graph): + self.begin() + self.dump_graph(graph, "before_async_tp_pass") + count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", count) + self.dump_graph(graph, "after_async_tp_pass") + self.end_and_log() diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py new file mode 100644 index 0000000..fd39a61 --- /dev/null +++ b/vllm/compilation/compiler_interface.py @@ -0,0 +1,564 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import copy +import hashlib +import os +from contextlib import ExitStack +from typing import Any, Callable, Optional +from unittest.mock import patch + +import torch +import torch._inductor.compile_fx +import torch.fx as fx + +import vllm.envs as envs +from vllm.compilation.counter import compilation_counter +from vllm.config import VllmConfig +from vllm.utils import is_torch_equal_or_newer + +from .inductor_pass import pass_context + + +class CompilerInterface: + """ + The interface for a compiler that can be used by vLLM. + """ + # The name of the compiler, e.g. inductor. + # This is a class-level attribute. + name: str + + def initialize_cache(self, + cache_dir: str, + disable_cache: bool = False, + prefix: str = ""): + """ + when the vLLM process uses `cache_dir` as the cache directory, + the compiler should initialize itself with the cache directory, + e.g. by re-directing its own cache directory to a sub-directory. + + prefix can be used in combination with cache_dir to figure out the base + cache directory, e.g. there're multiple parts of model being compiled, + but we want to share the same cache directory for all of them. + + e.g. + cache_dir = "/path/to/dir/backbone", prefix = "backbone" + cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head" + """ + pass + + def compute_hash(self, vllm_config: VllmConfig) -> str: + """ + Gather all the relevant information from the vLLM config, + to compute a hash so that we can cache the compiled model. + + See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash] + to check what information + is already considered by default. This function should only + consider the information that is specific to the compiler. + """ + return "" + + def compile( + self, + graph: fx.GraphModule, + example_inputs: list[Any], + compiler_config: dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, + ) -> tuple[Optional[Callable], Optional[Any]]: + """ + Compile the graph with the given example inputs and compiler config, + with a runtime shape. If the `runtime_shape` is None, it means + the `example_inputs` have a dynamic shape. Otherwise, the + `runtime_shape` specifies the shape of the inputs. Right now we only + support one variable shape for all inputs, which is the batchsize + (number of tokens) during inference. + + Dynamo will make sure `graph(*example_inputs)` is valid. + + The function should return a compiled callable function, as well as + a handle that can be used to directly load the compiled function. + + The handle should be a plain Python object, preferably a string or a + file path for readability. + + If the compiler doesn't support caching, it should return None for the + handle. If the compiler fails to compile the graph, it should return + None for the compiled function as well. + + `key` is required for StandaloneInductorAdapter, it specifies where to + save the compiled artifact. The compiled artifact gets saved to + `cache_dir/key`. + """ + return None, None + + def load(self, + handle: Any, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Callable: + """ + Load the compiled function from the handle. + Raises an error if the handle is invalid. + + The handle is the second return value of the `compile` function. + """ + raise NotImplementedError("caching is not supported") + + +class AlwaysHitShapeEnv: + """ + Why do we need this class: + + For normal `torch.compile` usage, every compilation will have + one Dynamo bytecode compilation and one Inductor compilation. + The Inductor compilation happens under the context of the + Dynamo bytecode compilation, and that context is used to + determine the dynamic shape information, etc. + + For our use case, we only run Dynamo bytecode compilation once, + and run Inductor compilation multiple times with different shapes + plus a general shape. The compilation for specific shapes happens + outside of the context of the Dynamo bytecode compilation. At that + time, we don't have shape environment to provide to Inductor, and + it will fail the Inductor code cache lookup. + + By providing a dummy shape environment that always hits, we can + make the Inductor code cache lookup always hit, and we can + compile the graph for different shapes as needed. + + The following dummy methods are obtained by trial-and-error + until it works. + """ + + def __init__(self) -> None: + self.guards: list[Any] = [] + + def evaluate_guards_expression(self, *args, **kwargs): + return True + + def get_pruned_guards(self, *args, **kwargs): + return [] + + def produce_guards_expression(self, *args, **kwargs): + return "" + + +def get_inductor_factors() -> list[Any]: + factors: list[Any] = [] + # summarize system state + from torch._inductor.codecache import CacheBase + system_factors = CacheBase.get_system() + factors.append(system_factors) + + # summarize pytorch state + from torch._inductor.codecache import torch_key + torch_factors = torch_key() + factors.append(torch_factors) + return factors + + +class InductorStandaloneAdaptor(CompilerInterface): + """ + The adaptor for the Inductor compiler. + Requires PyTorch 2.8+. + This is not on by default yet, but we plan to turn it on by default for + PyTorch 2.8. + + Use VLLM_USE_STANDALONE_COMPILE to toggle this on or off. + """ + name = "inductor_standalone" + + def compute_hash(self, vllm_config: VllmConfig) -> str: + factors = get_inductor_factors() + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest()[:10] + return hash_str + + def initialize_cache(self, + cache_dir: str, + disable_cache: bool = False, + prefix: str = ""): + self.cache_dir = cache_dir + + def compile( + self, + graph: fx.GraphModule, + example_inputs: list[Any], + compiler_config: dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, + ) -> tuple[Optional[Callable], Optional[Any]]: + compilation_counter.num_inductor_compiles += 1 + current_config = {} + if compiler_config is not None: + current_config.update(compiler_config) + set_inductor_config(current_config, runtime_shape) + + if isinstance(runtime_shape, int): + dynamic_shapes = "from_example_inputs" + else: + dynamic_shapes = "from_tracing_context" + + from torch._inductor import standalone_compile + with pass_context(runtime_shape): + compiled_graph = standalone_compile( + graph, + example_inputs, + dynamic_shapes=dynamic_shapes, + options={"config_patches": current_config}) + + # Save the compiled artifact to disk in the specified path + assert key is not None + path = os.path.join(self.cache_dir, key) + compiled_graph.save(path=path, format="unpacked") + return compiled_graph, (key, path) + + def load(self, + handle: Any, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Callable: + assert isinstance(handle, tuple) + assert isinstance(handle[0], str) + assert isinstance(handle[1], str) + path = handle[1] + inductor_compiled_graph = torch._inductor.CompiledArtifact.load( + path=path, format="unpacked") + from torch._inductor.compile_fx import graph_returns_tuple + returns_tuple = graph_returns_tuple(graph) + + def compiled_graph_wrapper(*args): + graph_output = inductor_compiled_graph(*args) + # unpack the tuple if needed + # TODO(rzou): the implication is that we're not + # reading the python bytecode correctly in vLLM? + if returns_tuple: + return graph_output + else: + return graph_output[0] + + return compiled_graph_wrapper + + +class InductorAdaptor(CompilerInterface): + """ + The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7. + """ + name = "inductor" + + def compute_hash(self, vllm_config: VllmConfig) -> str: + factors = get_inductor_factors() + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest()[:10] + return hash_str + + def initialize_cache(self, + cache_dir: str, + disable_cache: bool = False, + prefix: str = ""): + self.cache_dir = cache_dir + self.prefix = prefix + self.base_cache_dir = cache_dir[:-len(prefix)] if prefix else cache_dir + if disable_cache: + return + # redirect the cache directory to a sub-directory + # set flags so that Inductor and Triton store their cache + # in the cache_dir, then users only need to copy the cache_dir + # to another machine to reuse the cache. + inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache") + os.makedirs(inductor_cache, exist_ok=True) + os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache + triton_cache = os.path.join(self.base_cache_dir, "triton_cache") + os.makedirs(triton_cache, exist_ok=True) + os.environ["TRITON_CACHE_DIR"] = triton_cache + + def compile( + self, + graph: fx.GraphModule, + example_inputs: list[Any], + compiler_config: dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, + ) -> tuple[Optional[Callable], Optional[Any]]: + compilation_counter.num_inductor_compiles += 1 + from torch._inductor.compile_fx import compile_fx + current_config = {} + if compiler_config is not None: + current_config.update(compiler_config) + + # disable remote cache + current_config["fx_graph_cache"] = True + current_config["fx_graph_remote_cache"] = False + + set_inductor_config(current_config, runtime_shape) + + # inductor can inplace modify the graph, so we need to copy it + # see https://github.com/pytorch/pytorch/issues/138980 + graph = copy.deepcopy(graph) + + # it's the first time we compile this graph + # the assumption is that we don't have nested Inductor compilation. + # compiled_fx_graph_hash will only be called once, and we can hook + # it to get the hash of the compiled graph directly. + + hash_str, file_path = None, None + from torch._inductor.codecache import (FxGraphCache, + compiled_fx_graph_hash) + if torch.__version__.startswith("2.5"): + original_load = FxGraphCache.load + original_load_name = "torch._inductor.codecache.FxGraphCache.load" + + def hijack_load(*args, **kwargs): + inductor_compiled_graph = original_load(*args, **kwargs) + nonlocal file_path + compiled_fn = inductor_compiled_graph.current_callable + file_path = compiled_fn.__code__.co_filename # noqa + if not file_path.startswith(self.base_cache_dir): + # hooked in the align_inputs_from_check_idxs function + # in torch/_inductor/utils.py + for cell in compiled_fn.__closure__: + if not callable(cell.cell_contents): + continue + if cell.cell_contents.__code__.co_filename.startswith( + self.base_cache_dir): + # this is the real file path compiled from Inductor + file_path = cell.cell_contents.__code__.co_filename + break + return inductor_compiled_graph + + hijacked_compile_fx_inner = torch._inductor.compile_fx.compile_fx_inner # noqa + elif torch.__version__ >= "2.6": + # function renamed in 2.6 + original_load_name = None + + def hijacked_compile_fx_inner(*args, **kwargs): + output = torch._inductor.compile_fx.compile_fx_inner( + *args, **kwargs) + nonlocal hash_str + inductor_compiled_graph = output + if inductor_compiled_graph is not None: + nonlocal file_path + compiled_fn = inductor_compiled_graph.current_callable + file_path = compiled_fn.__code__.co_filename # noqa + if not file_path.startswith(self.base_cache_dir): + # hooked in the align_inputs_from_check_idxs function + # in torch/_inductor/utils.py + for cell in compiled_fn.__closure__: + if not callable(cell.cell_contents): + continue + code = cell.cell_contents.__code__ + if code.co_filename.startswith( + self.base_cache_dir): + # this is the real file path + # compiled from Inductor + file_path = code.co_filename + break + hash_str = inductor_compiled_graph._fx_graph_cache_key + return output + + def hijack_compiled_fx_graph_hash(*args, **kwargs): + out = compiled_fx_graph_hash(*args, **kwargs) + nonlocal hash_str + hash_str = out[0] + return out + + def _check_can_cache(*args, **kwargs): + # no error means it can be cached. + # Inductor refuses to cache the graph outside of Dynamo + # tracing context, and also disables caching for graphs + # with high-order ops. + # For vLLM, in either case, we want to cache the graph. + # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa + return + + def _get_shape_env() -> AlwaysHitShapeEnv: + return AlwaysHitShapeEnv() + + with ExitStack() as stack: + # hijack to get the compiled graph itself + if original_load_name is not None: + stack.enter_context(patch(original_load_name, hijack_load)) + + # for hijacking the hash of the compiled graph + stack.enter_context( + patch("torch._inductor.codecache.compiled_fx_graph_hash", + hijack_compiled_fx_graph_hash)) + + # for providing a dummy shape environment + stack.enter_context( + patch("torch._inductor.codecache.FxGraphCache._get_shape_env", + _get_shape_env)) + + from torch._functorch._aot_autograd.autograd_cache import ( + AOTAutogradCache) + + # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache + if hasattr(AOTAutogradCache, "_get_shape_env"): + stack.enter_context( + patch( + "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env", + _get_shape_env)) + + # for forcing the graph to be cached + stack.enter_context( + patch( + "torch._inductor.codecache.FxGraphCache._check_can_cache", + _check_can_cache)) + + # Dynamo metrics context, see method for more details. + stack.enter_context(self.metrics_context()) + + # Disable remote caching. When these are on, on remote cache-hit, + # the monkey-patched functions never actually get called. + # vLLM today assumes and requires the monkey-patched functions to + # get hit. + # TODO(zou3519): we're going to replace this all with + # standalone_compile sometime. + if is_torch_equal_or_newer("2.6"): + stack.enter_context( + torch._inductor.config.patch(fx_graph_remote_cache=False)) + stack.enter_context( + torch._functorch.config.patch( + enable_remote_autograd_cache=False)) + + with pass_context(runtime_shape): + compiled_graph = compile_fx( + graph, + example_inputs, + inner_compile=hijacked_compile_fx_inner, + config_patches=current_config) + + # We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch + # compilation cache. So turn off the checks if we disable the + # compilation cache. + if not envs.VLLM_DISABLE_COMPILE_CACHE: + if hash_str is None: + raise RuntimeError( + "vLLM failed to compile the model. The most " + "likely reason for this is that a previous compilation " + "failed, leading to a corrupted compilation artifact. " + "We recommend trying to " + "remove ~/.cache/vllm/torch_compile_cache and try again " + "to see the real issue. ") + assert file_path is not None, ( + "failed to get the file path of the compiled graph") + return compiled_graph, (hash_str, file_path) + + def load(self, + handle: Any, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Callable: + assert isinstance(handle, tuple) + assert isinstance(handle[0], str) + assert isinstance(handle[1], str) + hash_str = handle[0] + + from torch._functorch._aot_autograd.autograd_cache import ( + AOTAutogradCache) + from torch._inductor.codecache import FxGraphCache + with ExitStack() as exit_stack: + exit_stack.enter_context( + patch("torch._inductor.codecache.FxGraphCache._get_shape_env", + lambda *args, **kwargs: AlwaysHitShapeEnv())) + # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache + if hasattr(AOTAutogradCache, "_get_shape_env"): + exit_stack.enter_context( + patch( + "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env", + lambda *args, **kwargs: AlwaysHitShapeEnv())) + + # Dynamo metrics context, see method for more details. + exit_stack.enter_context(self.metrics_context()) + + if torch.__version__.startswith("2.5"): + inductor_compiled_graph = FxGraphCache._lookup_graph( + hash_str, example_inputs, True, False) + assert inductor_compiled_graph is not None, ( + "Inductor cache lookup failed. Please remove" + f"the cache directory and try again." # noqa + ) + elif torch.__version__ >= "2.6": + from torch._inductor.output_code import ( + CompiledFxGraphConstantsWithGm) + constants = CompiledFxGraphConstantsWithGm(graph) + inductor_compiled_graph, _ = FxGraphCache._lookup_graph( + hash_str, example_inputs, True, None, constants) + assert inductor_compiled_graph is not None, ( + "Inductor cache lookup failed. Please remove" + f"the cache directory and try again." # noqa + ) + + # Inductor calling convention (function signature): + # f(list) -> tuple + # Dynamo calling convention (function signature): + # f(*args) -> Any + + # need to know if the graph returns a tuple + from torch._inductor.compile_fx import graph_returns_tuple + returns_tuple = graph_returns_tuple(graph) + + # this is the callable we return to Dynamo to run + def compiled_graph(*args): + # convert args to list + list_args = list(args) + graph_output = inductor_compiled_graph(list_args) + # unpack the tuple if needed + if returns_tuple: + return graph_output + else: + return graph_output[0] + + return compiled_graph + + def metrics_context(self) -> contextlib.AbstractContextManager: + """ + This method returns the Dynamo metrics context (if it exists, + otherwise a null context). It is used by various compile components. + Present in torch>=2.6, it's used inside FxGraphCache in + torch==2.6 (but not after). It might also be used in various other + torch.compile internal functions. + + Because it is re-entrant, we always set it (even if entering via Dynamo + and the context was already entered). We might want to revisit if it + should be set at a different level of compilation. + + This is likely a bug in PyTorch: public APIs should not rely on + manually setting up internal contexts. But we also rely on non-public + APIs which might not provide these guarantees. + """ + if is_torch_equal_or_newer("2.6"): + import torch._dynamo.utils + return torch._dynamo.utils.get_metrics_context() + else: + return contextlib.nullcontext() + + +def set_inductor_config(config, runtime_shape): + if isinstance(runtime_shape, int): + # for a specific batchsize, tuning triton kernel parameters + # can be beneficial + config["max_autotune"] = True + config["coordinate_descent_tuning"] = True + + +class EagerAdaptor(CompilerInterface): + name = "eager" + + def compile( + self, + graph: fx.GraphModule, + example_inputs: list[Any], + compiler_config: dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, + ) -> tuple[Optional[Callable], Optional[Any]]: + compilation_counter.num_eager_compiles += 1 + # we don't need to compile the graph, just return the graph itself. + # It does not support caching, return None for the handle. + return graph, None diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py new file mode 100644 index 0000000..9d7a256 --- /dev/null +++ b/vllm/compilation/counter.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import copy +import dataclasses +from contextlib import contextmanager + + +@dataclasses.dataclass +class CompilationCounter: + num_models_seen: int = 0 + num_graphs_seen: int = 0 + # including the splitting ops + num_piecewise_graphs_seen: int = 0 + # not including the splitting ops + num_piecewise_capturable_graphs_seen: int = 0 + num_backend_compilations: int = 0 + # Number of gpu_model_runner attempts to trigger CUDAGraphs capture + num_gpu_runner_capture_triggers: int = 0 + # Number of CUDAGraphs captured + num_cudagraph_captured: int = 0 + # InductorAdapter.compile calls + num_inductor_compiles: int = 0 + # EagerAdapter.compile calls + num_eager_compiles: int = 0 + + def clone(self) -> "CompilationCounter": + return copy.deepcopy(self) + + @contextmanager + def expect(self, **kwargs): + old = self.clone() + yield + for k, v in kwargs.items(): + assert getattr(self, k) - getattr(old, k) == v, ( + f"{k} not as expected, before it is {getattr(old, k)}" + f", after it is {getattr(self, k)}, " + f"expected diff is {v}") + + +compilation_counter = CompilationCounter() diff --git a/vllm/compilation/cuda_piecewise_backend.py b/vllm/compilation/cuda_piecewise_backend.py new file mode 100644 index 0000000..8c49ea6 --- /dev/null +++ b/vllm/compilation/cuda_piecewise_backend.py @@ -0,0 +1,218 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +from contextlib import ExitStack +from typing import Any, Callable, Optional +from unittest.mock import patch + +import torch +import torch.fx as fx + +import vllm.envs as envs +from vllm.compilation.backends import VllmBackend +from vllm.compilation.counter import compilation_counter +from vllm.compilation.monitor import end_monitoring_torch_compile +from vllm.config import VllmConfig +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.utils import weak_ref_tensors + +logger = init_logger(__name__) + + +@dataclasses.dataclass +class ConcreteSizeEntry: + runtime_shape: int + need_to_compile: bool # the size is in compile_sizes + use_cudagraph: bool # the size is in cudagraph_capture_sizes + + compiled: bool = False + runnable: Callable = None # type: ignore + num_finished_warmup: int = 0 + cudagraph: Optional[torch.cuda.CUDAGraph] = None + output: Optional[Any] = None + + # for cudagraph debugging, track the input addresses + # during capture, and check if they are the same during replay + input_addresses: Optional[list[int]] = None + + +class CUDAPiecewiseBackend: + + def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, + graph_pool: Any, piecewise_compile_index: int, + total_piecewise_compiles: int, sym_shape_indices: list[int], + compiled_graph_for_general_shape: Callable, + vllm_backend: VllmBackend): + """ + The backend for piecewise compilation. + It mainly handles the compilation and cudagraph capturing. + + We will compile `self.graph` once for the general shape, + and then compile for different shapes specified in + `compilation_config.compile_sizes`. + + Independently, we will capture cudagraph for different shapes. + + If a shape needs both compilation and cudagraph, we will + compile it first, and then capture cudagraph. + """ + self.graph = graph + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.graph_pool = graph_pool + self.piecewise_compile_index = piecewise_compile_index + self.total_piecewise_compiles = total_piecewise_compiles + self.vllm_backend = vllm_backend + + self.is_first_graph = piecewise_compile_index == 0 + self.is_last_graph = ( + piecewise_compile_index == total_piecewise_compiles - 1) + + self.compile_sizes: set[int] = set( + self.compilation_config.compile_sizes) + self.cudagraph_capture_sizes: set[int] = set( + self.compilation_config.cudagraph_capture_sizes + ) if self.compilation_config.use_cudagraph else set() + + self.first_run_finished = False + + self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa + + self.sym_shape_indices = sym_shape_indices + + self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + + # the entries for different shapes that we need to either + # compile or capture cudagraph + self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} + + # to_be_compiled_sizes tracks the remaining sizes to compile, + # and updates during the compilation process, so we need to copy it + self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() + for shape in self.compile_sizes.union(self.cudagraph_capture_sizes): + self.concrete_size_entries[shape] = ConcreteSizeEntry( + runtime_shape=shape, + need_to_compile=shape in self.compile_sizes, + use_cudagraph=shape in self.cudagraph_capture_sizes, + ) + + def check_for_ending_compilation(self): + if self.is_last_graph and not self.to_be_compiled_sizes: + # no specific sizes to compile + # save the hash of the inductor graph for the next run + self.vllm_backend.compiler_manager.save_to_file() + end_monitoring_torch_compile(self.vllm_config) + + def __call__(self, *args) -> Any: + if not self.first_run_finished: + self.first_run_finished = True + self.check_for_ending_compilation() + return self.compiled_graph_for_general_shape(*args) + + runtime_shape = args[self.sym_shape_indices[0]] + if runtime_shape not in self.concrete_size_entries: + # we don't need to do anything for this shape + return self.compiled_graph_for_general_shape(*args) + + entry = self.concrete_size_entries[runtime_shape] + + if entry.runnable is None: + entry.runnable = self.compiled_graph_for_general_shape + + if entry.need_to_compile and not entry.compiled: + entry.compiled = True + self.to_be_compiled_sizes.remove(runtime_shape) + # args are real arguments + entry.runnable = self.vllm_backend.compiler_manager.compile( + self.graph, + args, + self.compilation_config.inductor_compile_config, + self.compilation_config, + graph_index=self.piecewise_compile_index, + num_graphs=self.total_piecewise_compiles, + runtime_shape=runtime_shape) + + # finished compilations for all required shapes + if self.is_last_graph and not self.to_be_compiled_sizes: + self.check_for_ending_compilation() + + # Skip CUDA graphs if this entry doesn't use them OR + # if we're supposed to skip them globally + skip_cuda_graphs = get_forward_context().skip_cuda_graphs + if not entry.use_cudagraph or skip_cuda_graphs: + return entry.runnable(*args) + + if entry.cudagraph is None: + if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa + entry.num_finished_warmup += 1 + if self.is_first_graph: + logger.debug( + "Warming up %s/%s for shape %s", + entry.num_finished_warmup, + self.compilation_config.cudagraph_num_of_warmups, + runtime_shape) + return entry.runnable(*args) + + if self.is_first_graph: + # Since we capture cudagraph for many different shapes and + # capturing is fast, we don't need to log it for every shape. + # We only log it in the debug mode. + logger.debug("Capturing a cudagraph for shape %s", + runtime_shape) + + input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + entry.input_addresses = input_addresses + cudagraph = torch.cuda.CUDAGraph() + + with ExitStack() as stack: + if not self.is_first_graph: + # during every model forward, we will capture + # many pieces of cudagraphs (roughly one per layer). + # running gc again and again across layers will + # make the cudagraph capture very slow. + # therefore, we only run gc for the first graph, + # and disable gc for the rest of the graphs. + stack.enter_context(patch("gc.collect", lambda: None)) + stack.enter_context( + patch("torch.cuda.empty_cache", lambda: None)) + + # mind-exploding: carefully manage the reference and memory. + with torch.cuda.graph(cudagraph, pool=self.graph_pool): + # `output` is managed by pytorch's cudagraph pool + output = entry.runnable(*args) + if self.is_last_graph: + # by converting it to weak ref, + # the original `output` will immediately be released + # to save memory. It is only safe to do this for + # the last graph, because the output of the last graph + # will not be used by any other cuda graph. + output = weak_ref_tensors(output) + + # here we always use weak ref for the output + # to save memory + entry.output = weak_ref_tensors(output) + entry.cudagraph = cudagraph + + compilation_counter.num_cudagraph_captured += 1 + + # important: we need to return the output, rather than + # the weak ref of the output, so that pytorch can correctly + # manage the memory during cuda graph capture + return output + + if self.is_debugging_mode: + # check if the input addresses are the same + new_input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + assert new_input_addresses == entry.input_addresses, ( + "Input addresses for cudagraphs are different during replay." + f" Expected {entry.input_addresses}, got {new_input_addresses}" + ) + + entry.cudagraph.replay() + return entry.output diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py new file mode 100644 index 0000000..5620f7e --- /dev/null +++ b/vllm/compilation/decorators.py @@ -0,0 +1,256 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import inspect +from typing import Callable, Optional, TypeVar, Union, overload +from unittest.mock import patch + +import torch +import torch.nn as nn +from torch._dynamo.symbolic_convert import InliningInstructionTranslator + +from vllm import envs +from vllm.compilation.counter import compilation_counter +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.forward_context import get_forward_context, get_profilling +from vllm.config import CompilationLevel, VllmConfig +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors +from vllm.utils import supports_dynamo + +from .monitor import start_monitoring_torch_compile + +logger = init_logger(__name__) + +_T = TypeVar("_T", bound=type[nn.Module]) + + +@overload +def support_torch_compile( + *, + dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]], +) -> Callable[[_T], _T]: + ... + + +@overload +def support_torch_compile(cls: _T) -> _T: + ... + + +def support_torch_compile( + cls: Optional[_T] = None, + *, + dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None, +) -> Union[Callable[[_T], _T], _T]: + """ + A decorator to add support for compiling the forward method of a class. + + Usage 1: use directly as a decorator without arguments: + + ```python + @support_torch_compile + class MyModel(nn.Module): + def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): + ... + ``` + + Usage 2: use as a decorator with arguments: + + ```python + @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0}) + class MyModel(nn.Module): + def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): + ... + ``` + + `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic + dimensions of the argument. The dynamic dimensions can be either a single + integer or a list of integers. + + if `dynamic_arg_dims` is `None`, it is inferred from the type annotation + of the `forward` method, based on the following default rules: + + - if the argument is annotated as `torch.Tensor` or + `Optional[torch.Tensor]`, the first dimension will be + marked as dynamic. + - if the argument is annotated as `IntermediateTensors`, the first + dimension of all the tensors in the intermediate tensors + will be marked as dynamic. + + During runtime, when we actually mark dimensions of tensors, + it depends on the value of arguments: + + - if it is a single integer (can be negative), the corresponding dimension + of the argument will be marked as dynamic. + - if it is `None`, ignored. + - if it is `IntermediateTensors`, all the tensors in the intermediate + tensors will be marked as dynamic. + - otherwise, it will raise an error. + + NOTE: if an argument is `None`, it should always be passed as `None` during + the lifetime of the model, otherwise, it cannot be captured as a single + computation graph. + """ + + def cls_decorator_helper(cls: _T) -> _T: + # helper to pass `dynamic_arg_dims`` to `_support_torch_compile`` + # to avoid too much indentation for `_support_torch_compile`` + if not hasattr(cls, 'forward'): + raise TypeError("decorated class should have a forward method.") + sig = inspect.signature(cls.forward) + inferred_dynamic_arg_dims = dynamic_arg_dims + if inferred_dynamic_arg_dims is None: + inferred_dynamic_arg_dims = {} + for k, v in sig.parameters.items(): + if v.annotation in [ + torch.Tensor, Optional[torch.Tensor], + IntermediateTensors, Optional[IntermediateTensors] + ]: + inferred_dynamic_arg_dims[k] = 0 + + logger.debug(("Inferred dynamic dimensions for " + "forward method of %s: %s"), cls, + list(inferred_dynamic_arg_dims.keys())) + + if len(inferred_dynamic_arg_dims) == 0: + raise ValueError( + "No dynamic dimensions found in the forward method of " + f"{cls}. Please provide dynamic_arg_dims explicitly.") + + for k in inferred_dynamic_arg_dims: + if k not in sig.parameters: + raise ValueError( + f"Argument {k} not found in the forward method of {cls}") + return _support_torch_compile(cls, inferred_dynamic_arg_dims) + + if cls is not None: + # use `support_torch_compile` as a decorator without arguments + assert isinstance(cls, type) + return cls_decorator_helper(cls) + + return cls_decorator_helper + + +def _support_torch_compile( + cls: _T, + dynamic_arg_dims: dict[str, Union[int, list[int]]], +) -> _T: + """ + A decorator to add support for compiling the forward method of a class. + """ + if TorchCompileWrapperWithCustomDispatcher in cls.__bases__: + # support decorating multiple times + return cls + + # take care of method resolution order + # make sure super().__init__ is called on the base class + # other than TorchCompileWrapperWithCustomDispatcher + cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, ) + + old_init = cls.__init__ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): + old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) + self.vllm_config = vllm_config + # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner + # will handle the compilation, so we don't need to do anything here. + self.do_not_compile = \ + vllm_config.compilation_config.level in [ + CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS + ] or not supports_dynamo() + if self.do_not_compile: + return + compilation_counter.num_models_seen += 1 + TorchCompileWrapperWithCustomDispatcher.__init__( + self, compilation_level=vllm_config.compilation_config.level) + + cls.__init__ = __init__ + + def __call__(self, *args, **kwargs): + # torch.compiler.is_compiling() means we are inside the compilation + # e.g. TPU has the compilation logic in model runner, so we don't + # need to compile the model inside. + skip_cuda_graphs = get_forward_context().skip_cuda_graphs + if envs.VLLM_ENABLE_TBO and skip_cuda_graphs: + return self.forward(*args, **kwargs) + + if self.do_not_compile or torch.compiler.is_compiling() or get_profilling(): + return self.forward(*args, **kwargs) + + # the first compilation needs to have dynamic shapes marked + if len(self.compiled_codes) < 1: + sig = inspect.signature(self.__class__.forward) + bound_args = sig.bind(self, *args, **kwargs) + bound_args.apply_defaults() + for k, dims in dynamic_arg_dims.items(): + arg = bound_args.arguments.get(k) + if arg is not None: + dims = [dims] if isinstance(dims, int) else dims + if isinstance(arg, torch.Tensor): + # In case dims is specified with negative indexing + dims = [ + arg.ndim + dim if dim < 0 else dim for dim in dims + ] + torch._dynamo.mark_dynamic(arg, dims) + elif isinstance(arg, IntermediateTensors): + for tensor in arg.tensors.values(): + # In case dims is specified with negative indexing + dims = [ + tensor.ndim + dim if dim < 0 else dim + for dim in dims + ] + torch._dynamo.mark_dynamic(tensor, dims) + else: + raise ValueError( + "Unsupported dynamic dimensions" + f" {dims} for argument {k} with type {type(arg)}.") + # here, it is the starting point of the `torch.compile` process + start_monitoring_torch_compile(self.vllm_config) + logger.debug("Start compiling function %s", + self.original_code_object) + + # if we don't use custom dispatcher, we can directly call the + # compiled function and let torch.compile handle the dispatching, + # with the overhead of guard evaluation and recompilation. + if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher: + # it seems Dynamo reuse the compilation across instances, + # while we need to make sure the compiled code is not reused. + # we need to control all the compilation of the model. + torch._dynamo.eval_frame.remove_from_cache( + self.original_code_object) + + # collect all relevant files traced by Dynamo, + # so that the compilation cache can trigger re-compilation + # properly when any of these files change. + + # 1. the file containing the top-level forward function + self.vllm_config.compilation_config.traced_files.add( + self.original_code_object.co_filename) + + # 2. every time Dynamo sees a function call, it will inline + # the function by calling InliningInstructionTranslator.inline_call + # we hijack this function to know all the functions called + # during Dynamo tracing, and their corresponding files + inline_call = InliningInstructionTranslator.inline_call + + def patched_inline_call(parent, func, args, kwargs): + code = func.get_code() + self.vllm_config.compilation_config.traced_files.add( + code.co_filename) + return inline_call(parent, func, args, kwargs) + + with patch.object(InliningInstructionTranslator, 'inline_call', + patched_inline_call): + output = self.compiled_callable(*args, **kwargs) + return output + + # usually, capturing the model once is enough, and then we can + # dispatch to the compiled code directly, without going through + # the Dynamo guard mechanism. + with self.dispatch_to_code(0): + model_output = self.forward(*args, **kwargs) + return model_output + + cls.__call__ = __call__ + return cls diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py new file mode 100644 index 0000000..48f32c8 --- /dev/null +++ b/vllm/compilation/fix_functionalization.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import operator +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch._higher_order_ops.auto_functionalize import auto_functionalized + +from vllm.logger import init_logger + +from .fx_utils import is_func +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +class FixFunctionalizationPass(VllmInductorPass): + """ + This pass defunctionalizes certain nodes to avoid redundant tensor copies. + After this pass, DCE (dead-code elimination) should never be run, + as de-functionalized nodes may appear as dead code. + + To add new nodes to defunctionalize, add to the if-elif chain in __call__. + """ + + def __call__(self, graph: torch.fx.Graph): + self.begin() + self.dump_graph(graph, "before_fix_functionalization") + + self.nodes_to_remove: list[torch.fx.Node] = [] + count = 0 + for node in graph.nodes: + if not is_func(node, auto_functionalized): + continue # Avoid deep if-elif nesting + + kwargs = node.kwargs + at_target = node.args[0] + + if at_target == torch.ops._C.rotary_embedding.default: + query = kwargs['query'] + mm_node = query.args[0].args[0] + + # rotary_embedding is a special case: the two mutating inputs + # are query and key, which are slices of mm_node. + # While functionalized, results at[1] and at[2] are scattered + # back into mm_node. After de-functionalization, we can just + # use mm_node directly. + for idx, user in self.getitem_users(node).items(): + for user_of_getitem in user.users: + if is_func(user_of_getitem, + torch.ops.aten.slice_scatter.default): + user_of_getitem.replace_all_uses_with(mm_node) + self._remove(user_of_getitem) + self._remove(user) + + self.insert_defunctionalized(graph, node) + self._remove(node) + + # rms_norm replacements avoid the most copies for LLaMa. + elif at_target == torch.ops._C.fused_add_rms_norm.default: + mutated_args = {1: 'input', 2: 'residual'} + self.defunctionalize(graph, node, mutated_args) + # elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501 + # mutated_args = {1: 'result', 2: 'residual'} + # self.defunctionalize(graph, node, mutated_args) + elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501 + mutated_args = {1: 'result', 2: 'scale', 3: 'residual'} + self.defunctionalize(graph, node, mutated_args) + elif at_target in [ + torch.ops._C.rms_norm.default, + # torch.ops._C.rms_norm_static_fp8_quant.default, + ]: + mutated_args = {1: 'result'} + self.defunctionalize(graph, node, mutated_args) + # For some reason we need to specify the args for both + # silu_and_mul and silu_and_mul_quant. The kwargs + # pathway gets the wrong answer. + elif at_target == torch.ops._C.silu_and_mul.default: + mutated_args = {1: 'result'} + self.defunctionalize(graph, + node, + mutated_args, + args=('result', 'input')) + # elif at_target == torch.ops._C.silu_and_mul_quant.default: + # mutated_args = {1: 'result'} + # self.defunctionalize(graph, + # node, + # mutated_args, + # args=('result', 'input', 'scale')) + else: + continue # skip the count + + count += 1 + + self.dump_graph(graph, "before_fix_functionalization_cleanup") + + # Remove the nodes all at once + count_removed = len(self.nodes_to_remove) + for node in self.nodes_to_remove: + graph.erase_node(node) + + logger.debug("De-functionalized %s nodes, removed %s nodes", count, + count_removed) + self.dump_graph(graph, "after_fix_functionalization") + self.end_and_log() + + def _remove(self, node_or_nodes: Union[torch.fx.Node, + Iterable[torch.fx.Node]]): + """ + Stage a node (or nodes) for removal at the end of the pass. + """ + if isinstance(node_or_nodes, torch.fx.Node): + self.nodes_to_remove.append(node_or_nodes) + else: + self.nodes_to_remove.extend(node_or_nodes) + + def defunctionalize(self, + graph: torch.fx.Graph, + node: torch.fx.Node, + mutated_args: dict[int, Union[torch.fx.Node, str]], + args: Optional[tuple[Union[torch.fx.Node, str], + ...]] = None): + """ + De-functionalize a node by replacing it with a call to the original. + It also replaces the getitem users with the mutated arguments. + See replace_users_with_mutated_args and insert_defunctionalized. + """ + self.replace_users_with_mutated_args(node, mutated_args) + self.insert_defunctionalized(graph, node, args=args) + self._remove(node) + + def replace_users_with_mutated_args(self, node: torch.fx.Node, + mutated_args: dict[int, + Union[torch.fx.Node, + str]]): + """ + Replace all getitem users of the auto-functionalized node with the + mutated arguments. + :param node: The auto-functionalized node + :param mutated_args: The mutated arguments, indexed by getitem index. + If the value of an arg is a string, `node.kwargs[arg]` is used. + """ + for idx, user in self.getitem_users(node).items(): + arg = mutated_args[idx] + arg = node.kwargs[arg] if isinstance(arg, str) else arg + user.replace_all_uses_with(arg) + self._remove(user) + + def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]: + """ + Returns the operator.getitem users of the auto-functionalized node, + indexed by the index they are getting. + """ + users = {} + for user in node.users: + if is_func(user, operator.getitem): + idx = user.args[1] + users[idx] = user + return users + + def insert_defunctionalized(self, + graph: torch.fx.Graph, + node: torch.fx.Node, + args: Optional[tuple[Union[torch.fx.Node, str], + ...]] = None): + """ + Insert a new defunctionalized node into the graph before node. + If one of the kwargs is 'out', provide args directly, + as node.kwargs cannot be used. + See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 + + :param graph: Graph to insert the defunctionalized node into + :param node: The auto-functionalized node to defunctionalize + :param args: If we cannot use kwargs, specify args directly. + If an arg is a string, `node.kwargs[arg]` is used. + """ # noqa: E501 + assert is_func(node, auto_functionalized), \ + f"node must be auto-functionalized, is {node} instead" + + # Create a new call to the original function + with graph.inserting_before(node): + function = node.args[0] + if args is None: + graph.call_function(function, kwargs=node.kwargs) + else: + # Args passed as strings refer to items in node.kwargs + args = tuple(node.kwargs[arg] if isinstance(arg, str) else arg + for arg in args) + graph.call_function(function, args=args) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py new file mode 100644 index 0000000..97b6c9d --- /dev/null +++ b/vllm/compilation/fusion.py @@ -0,0 +1,645 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Callable, ClassVar, NamedTuple, Optional + +import torch +import torch._inductor.pattern_matcher as pm +from torch import fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor.pattern_matcher import PatternMatcherPass +from torch._ops import OpOverload + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from .fx_utils import find_getitem_maybe +from .multi_output_match import MultiOutputMatch +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) +FP8_DTYPE = current_platform.fp8_dtype() + + +def empty_bf16(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") + + +def empty_fp32(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") + + +RMS_OP = torch.ops._C.rms_norm.default +RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default + + +# Use proxy as NamedTuple direct subclasses cannot have static members +class _GroupShape(NamedTuple): + row: int + col: int + + +class GroupShape(_GroupShape): + """ + This class describes the quantization group shape. + It includes static members for common shapes (per-tensor, per-token). + """ + + # Aliases for common quantization group shapes + PER_TENSOR: ClassVar['GroupShape'] + PER_TOKEN: ClassVar['GroupShape'] + + +GroupShape.PER_TENSOR = GroupShape(-1, -1) +GroupShape.PER_TOKEN = GroupShape(1, -1) + + +class QuantKey(NamedTuple): + """ + Named tuple for identifying the type of quantization. + dtype: quantized data type + static: static quantization if True, dynamic if False + group_shape: quantization group shape + symmetric: symmetric if True, asymmetric if False + + TODO(luka) use QuantDescriptor once standardized: + https://github.com/vllm-project/vllm/issues/8913 + + """ + dtype: torch.dtype + static: bool + group_shape: GroupShape + symmetric: bool = True + + def __str__(self): + group_shape = ('per_tensor' + if self.group_shape == GroupShape.PER_TENSOR else + ('per_token' if self.group_shape == GroupShape.PER_TOKEN + else str(self.group_shape))) + + return (f"QuantKey({'static' if self.static else 'dynamic'}," + f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape}," + f"{'a' if not self.symmetric else ''}symmetric)") + + +# kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True) +# kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True) +# kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True) + +QUANT_OPS: dict[QuantKey, OpOverload] = { + # kFp8StaticTensorSym: + # torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 + # kFp8DynamicTensorSym: + # torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 + # kFp8DynamicTokenSym: + # torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 +} + + +class FusedRMSQuantKey(NamedTuple): + """ + Named tuple for identifying the type of RMSNorm + quant fusion. + quant: type of quantization + fused_add: does the op also perform the residual add + """ + quant: QuantKey + fused_add: bool + + def __str__(self): + return (f"FusedQuantKey({self.quant}, with" + f"{'' if self.fused_add else 'out'} residual)") + + +FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = { + # FusedRMSQuantKey(kFp8StaticTensorSym, False): + # torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501 + # FusedRMSQuantKey(kFp8StaticTensorSym, True): + # torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501 + # FusedRMSQuantKey(kFp8DynamicTokenSym, False): + # torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 + # FusedRMSQuantKey(kFp8DynamicTokenSym, True): + # torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 +} + + +class QuantMultiOutputMatch(MultiOutputMatch): + + def __init__(self, match: pm.Match, quant_op, fused_op): + super().__init__(match) + assert isinstance(quant_op, OpOverload) + assert isinstance(fused_op, OpOverload) + self.QUANT_OP = quant_op # in-place quant op + self.FUSED_OP = fused_op # in-place fused quant op + + def insert_fused_node(self, fused_return_mapping: dict[int, tuple[fx.Node, + int]], + **kwargs): + """ + This utility function inserts an auto-functionalized node for FUSED_OP. + It also correctly sets its meta value and rebinds the users of the + unfused nodes to use the fused node instead. + + :param fused_return_mapping: A dictionary, mapping from getitem indices + of the fused node result to a tuple of the old node and a getitem index. + :param kwargs: kwargs that get directly forwarded to the auto_fn node + + Example: + If we want to replace this graph: + _, x1, x2 = auto_fn(op1) + _, y1, y2 = auto_fn(op2) + + with + _, x1, y2, x2 = auto_fn(FUSED_OP) + + we would call: + insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)} + + Note that the 0th element is None for auto-functionalized in-place ops. + Hence, others appear 1-indexed. + """ + fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs) + indices = fused_return_mapping.keys() + getitem_nodes = self.insert_getitems(fused_node, indices) + + # Prepare the meta value, use a list so it's mutable + meta_val = [None] * (max(indices) + 1) + + # Iterate through elements of the tuple produced by fused_node + for idx, getitem_node in zip(indices, getitem_nodes): + old_node, old_idx = fused_return_mapping[idx] + + # If the old value was never used, the old_getitem might not exist + old_getitem = find_getitem_maybe(old_node, old_idx) + if old_getitem is not None: + # Rebind the users of match getitem nodes to use the new nodes. + # The old nodes will be removed by DCE at the end of the pass. + old_getitem.replace_all_uses_with(getitem_node) + getitem_node.meta["val"] = old_getitem.meta["val"] + + # Extract the appropriate meta value + # It is present even if the getitem node does not exist + meta_val[idx] = old_node.meta["val"][old_idx] + + # Fix the meta value on the new fused node + fused_node.meta["val"] = tuple(meta_val) + + +class RMSNormQuantPattern: + + def __init__(self, epsilon: float, key: FusedRMSQuantKey): + self.epsilon = epsilon + self.quant_dtype = key.quant.dtype + + assert key.quant in QUANT_OPS, \ + f"unsupported quantization scheme {key.quant}" + self.QUANT_OP = QUANT_OPS[key.quant] + + assert key in FUSED_OPS, \ + f"unsupported fused rmsnorm+quant op for {key}" + self.FUSED_OP = FUSED_OPS[key] + + +class RMSNormStaticQuantPattern(RMSNormQuantPattern): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + symmetric=True): + fused_key = FusedRMSQuantKey(fused_add=False, + quant=QuantKey( + dtype=quant_dtype, + static=True, + group_shape=GroupShape.PER_TENSOR, + symmetric=symmetric)) + super().__init__(epsilon, fused_key) + + def register(self, pm_pass: PatternMatcherPass): + # Cannot use methods, as the self argument affects tracing + def pattern(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at1 = auto_functionalized(RMS_OP, + result=result_rms, + input=input, + weight=weight, + epsilon=self.epsilon) + at2 = auto_functionalized(self.QUANT_OP, + result=result, + input=at1[1], + scale=scale) + + # result + return at2[1] + + def replacement(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon) + + # result + return at[1] + + inputs = [ + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result + empty_bf16(5, 4), # result_rms + empty_bf16(5, 4), # input + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, + pm_pass) + + +class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + symmetric=True): + key = FusedRMSQuantKey(fused_add=True, + quant=QuantKey( + dtype=quant_dtype, + static=True, + group_shape=GroupShape.PER_TENSOR, + symmetric=symmetric)) + super().__init__(epsilon, key) + + def register(self, pm_pass: PatternMatcherPass, + record_match: Callable[[MultiOutputMatch], bool]): + + def pattern(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon) + at1 = auto_functionalized(self.QUANT_OP, + result=result, + input=at[1], + scale=scale) + + # result, residual + return at1[1], at[2] + + def replacement(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + residual=residual, + weight=weight, + scale=scale, + epsilon=self.epsilon) + + # result, residual + return at[1], at[2] + + inputs = [ + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result + empty_bf16(5, 4), # input + empty_bf16(5, 4), # residual + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement( + pattern, + replacement, + inputs, + pm.fwd_only, + pm_pass, + extra_check=lambda m: record_match( + self.Match(m, self.QUANT_OP, self.FUSED_OP))) + + class Match(QuantMultiOutputMatch): + + def process(self): + # Find the nodes in the match that we need to rebind + rms_node = self.find_auto_fn(RMS_ADD_OP) + quant_node = self.find_auto_fn(self.QUANT_OP) + + assert len(rms_node.users) == 2 + assert len(quant_node.users) == 1 + + # First, insert a new auto_functionalized node for the fused op, + # as well as getitem nodes to extract the result and residual. + # The auto_fn node returns a tuple of (None, result, residual). + # + # The resulting graph looks like this: + # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa + # result_node_new = at[1] + # residual_node_new = at[2] + with self.inserting_after_match(): + # Missing epsilon, scalars cannot be inputs to the pattern + kwargs = self.match.kwargs.copy() + + # 0 is always None + fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)} + self.insert_fused_node(fused_return_mapping, + **kwargs, + epsilon=rms_node.kwargs["epsilon"]) + + +class RMSNormDynamicQuantPattern(RMSNormQuantPattern): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape = GroupShape.PER_TOKEN, + symmetric=True): + key = FusedRMSQuantKey(fused_add=False, + quant=QuantKey(dtype=quant_dtype, + static=False, + group_shape=group_shape, + symmetric=symmetric)) + super().__init__(epsilon, key) + + def register(self, pm_pass: PatternMatcherPass, + record_match: Callable[[MultiOutputMatch], bool]): + + def pattern(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at1 = auto_functionalized(RMS_OP, + result=result_rms, + input=input, + weight=weight, + epsilon=self.epsilon) + at2 = auto_functionalized(self.QUANT_OP, + result=result, + input=at1[1], + scale=scale, + scale_ub=None) + + # result, scale + return at2[1], at2[2] + + def replacement(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=None) + + # result, scale + return at[1], at[2] + + inputs = [ + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result + empty_bf16(5, 4), # result_rms + empty_bf16(5, 4), # input + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement( + pattern, + replacement, + inputs, + pm.fwd_only, + pm_pass, + extra_check=lambda m: record_match( + self.Match(m, self.QUANT_OP, self.FUSED_OP))) + + class Match(QuantMultiOutputMatch): + + def process(self): + # Find the nodes in the match that we need to rebind + rms_node = self.find_auto_fn(RMS_OP) + quant_node = self.find_auto_fn(self.QUANT_OP) + + assert len(rms_node.users) == 1 + assert len(quant_node.users) == 2 + + # First, insert a new auto_functionalized node for the fused op, + # as well as getitem nodes to extract the result and scale. + # The auto_fn node returns a tuple of (None, result, scale). + # + # The resulting graph looks like this: + # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa + # result_node_new = at[1] + # scale_node_new = at[2] + with self.inserting_after_match(): + # Missing epsilon, scalars cannot be inputs to the pattern + kwargs = self.match.kwargs.copy() + del kwargs["result_rms"] # not used in the fused op + + fused_return_mapping = {1: (quant_node, 1), 2: (quant_node, 2)} + self.insert_fused_node( + fused_return_mapping, + epsilon=rms_node.kwargs["epsilon"], + scale_ub=None, # not used but required + residual=None, # not used but required + **kwargs) + + +class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape = GroupShape.PER_TOKEN, + symmetric=True): + key = FusedRMSQuantKey(fused_add=True, + quant=QuantKey(dtype=quant_dtype, + static=False, + group_shape=group_shape, + symmetric=symmetric)) + super().__init__(epsilon, key) + + def register(self, pm_pass: PatternMatcherPass, + record_match: Callable[[MultiOutputMatch], bool]): + + def pattern(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon) + at1 = auto_functionalized(self.QUANT_OP, + result=result, + input=at[1], + scale=scale, + scale_ub=None) + + # result, residual, scale + return at1[1], at[2], at1[2] + + def replacement(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=residual) + + # result, residual, scale + return at[1], at[3], at[2] + + inputs = [ + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result + empty_bf16(5, 4), # input + empty_bf16(5, 4), # residual + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement( + pattern, + replacement, + inputs, + pm.fwd_only, + pm_pass, + extra_check=lambda m: record_match( + self.Match(m, self.QUANT_OP, self.FUSED_OP))) + + class Match(QuantMultiOutputMatch): + + def process(self): + # Find the nodes in the match that we need to rebind + rms_node = self.find_auto_fn(RMS_ADD_OP) + quant_node = self.find_auto_fn(self.QUANT_OP) + + assert len(rms_node.users) == 2 + assert len(quant_node.users) == 2 + + # First, insert a new auto_functionalized node for the fused op, + # as well as getitem nodes to extract result, scale, and residual. + # The auto_fn node returns a tuple (None, result, scale, residual). + # + # The resulting graph looks like this: + # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa + # result_node_new = at[1] + # scale_node_new = at[2] + # residual_node_new = at[3] + with self.inserting_after_match(): + # Missing epsilon, scalars cannot be inputs to the pattern + kwargs = self.match.kwargs.copy() + + fused_return_mapping = { + 1: (quant_node, 1), # result + 2: (quant_node, 2), # scale + 3: (rms_node, 2), # residual + } + self.insert_fused_node( + fused_return_mapping, + epsilon=rms_node.kwargs["epsilon"], + scale_ub=None, # not used but required + **kwargs) + + +class FusionPass(VllmInductorPass): + """ + This pass fuses a pre-defined set of custom ops into fused ops. + It uses the torch pattern matcher to find the patterns and replace them. + It also manually processes multi-output matches, as those are broken in + the torch pattern matcher. + + Because patterns can only be registered once, the pass is a singleton. + This will be addressed in a future version of PyTorch: + https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 + """ + + _instance: 'Optional[FusionPass]' = None + + @classmethod + def instance(cls, config: VllmConfig): + """ + Get the singleton instance of the FusionPass. + If the instance exists, the config is updated but + initialization is not repeated. + """ + if cls._instance is None: + cls._instance = FusionPass(config) + else: + cls._instance.pass_config = config.compilation_config.pass_config + return cls._instance + + def __init__(self, config: VllmConfig): + assert self.__class__._instance is None, \ + "FusionPass singleton instance already exists" + super().__init__(config) + + self.matches: list[MultiOutputMatch] = [] + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="fusion_pass") + + for epsilon in [1e-5, 1e-6]: + # Fuse rms_norm + static fp8 quant + # RMSNormStaticQuantPattern(epsilon, + # FP8_DTYPE).register(self.patterns) + + # Matches for patterns below have 2 or more outputs, + # so we need to process them manually (see process_matches) + + # Fuse rms_norm + static fp8 quant + # FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( + # self.patterns, self.record_match) + + # Fuse rms_norm + dynamic per-token fp8 quant + # RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( + # self.patterns, self.record_match) + + # Fuse fused_add_rms_norm + dynamic per-token fp8 quant + # FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( + # self.patterns, self.record_match) + + # WARNING: This is a hack to clear the pattern matcher cache + # and allow multiple values of epsilon. + torch._inductor.pattern_matcher._seen_patterns.clear() + + def record_match(self, match: MultiOutputMatch) -> bool: + # Hijack the extra_check to record the match and + # save it for post-processing. + self.matches.append(match) + + # Return False to prevent automatic replacement. + return False + + def process_matches(self, graph: fx.Graph): + """ + Manually process multi-output matches and replace them with fused nodes. + See MultiOutputMatch for more details. + """ + for match in self.matches: + match.process() + + # Finally, remove matched nodes + graph.eliminate_dead_code() + assert all(node not in graph.nodes for match in self.matches + for node in match.match.nodes) + + def __call__(self, graph: fx.Graph): + self.begin() + self.dump_graph(graph, "before_fusion") + + count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", count) + self.dump_graph(graph, "after_pattern_match") + + # Manually process multi-output matches (and run DCE) + self.process_matches(graph) + logger.debug("Post-processed %s matches", len(self.matches)) + self.dump_graph(graph, "after_fusion") + self.matches.clear() + self.end_and_log() diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py new file mode 100644 index 0000000..79518b6 --- /dev/null +++ b/vllm/compilation/fusion_attn.py @@ -0,0 +1,166 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch._inductor.pattern_matcher as pm +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor.pattern_matcher import PatternMatcherPass +from torch._subclasses.fake_tensor import (FakeTensorMode, + unset_fake_temporarily) + +from vllm.attention import Attention +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from .fusion import QUANT_OPS, GroupShape, QuantKey, empty_bf16, empty_fp32 +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + +ATTN_OP = torch.ops.vllm.unified_attention_with_output.default +RESHAPE_OP = torch.ops.aten.reshape.default + + +class AttentionStaticQuantPattern: + + def __init__( + self, + layer_name: str, + num_heads: int, + head_size: int, + quant_dtype: torch.dtype, + symmetric=True, + ): + self.layer_name = layer_name + self.num_heads = num_heads + self.head_size = head_size + self.quant_dtype = quant_dtype + self.quant_key = QuantKey(dtype=quant_dtype, + static=True, + group_shape=GroupShape.PER_TENSOR, + symmetric=symmetric) + assert self.quant_key in QUANT_OPS, \ + f"unsupported quantization scheme {self.quant_key}" + self.QUANT_OP = QUANT_OPS[self.quant_key] + + def empty_quant(self, *args, **kwargs): + kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} + return torch.empty(*args, **kwargs) + + def register_if_supported(self, pm_pass: PatternMatcherPass, + layer: Attention): + if layer.impl.fused_output_quant_supported(self.quant_dtype, + self.quant_key.static, + self.quant_key.group_shape): + self._register(pm_pass) + + def _register(self, pm_pass: PatternMatcherPass): + + def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + output_attn: torch.Tensor, output_quant: torch.Tensor, + scale: torch.Tensor): + view_7 = RESHAPE_OP(output_attn, + [-1, self.num_heads, self.head_size]) + + at1 = auto_functionalized(ATTN_OP, + query=q, + key=k, + value=v, + output=view_7, + layer_name=self.layer_name, + output_scale=None) + attn_out_view = RESHAPE_OP(at1[1], + [-1, self.num_heads * self.head_size]) + + at2 = auto_functionalized(self.QUANT_OP, + result=output_quant, + input=attn_out_view, + scale=scale) + return at2[1] + + def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + output_attn: torch.Tensor, output_quant: torch.Tensor, + scale: torch.Tensor): + view_7 = RESHAPE_OP(output_quant, + [-1, self.num_heads, self.head_size]) + + at1 = auto_functionalized(ATTN_OP, + query=q, + key=k, + value=v, + output=view_7, + layer_name=self.layer_name, + output_scale=scale) + + return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) + + # Need custom fake mode, otherwise tracing happens with real tensors. + # That would not work for the unified_attention custom op. + with unset_fake_temporarily(), FakeTensorMode(): + inputs = [ + empty_bf16(5, self.num_heads, self.head_size), # q + empty_bf16(5, self.num_heads, self.head_size), # k + empty_bf16(5, self.num_heads, self.head_size), # v + empty_bf16(5, self.num_heads * self.head_size), # attn_output + self.empty_quant(5, self.num_heads * + self.head_size), # quant_output + empty_fp32(1, 1) # scale + ] + + def wrap_trace_fn(process_fx, trace_fn): + + def wrapped(*args, **kwargs): + return process_fx(trace_fn(*args, **kwargs)) + + return wrapped + + def fx_view_to_reshape(gm: torch.fx.GraphModule): + from torch._inductor.fx_passes.post_grad import view_to_reshape + view_to_reshape(gm) + return gm + + pm.register_replacement( + pattern, replacement, inputs, + wrap_trace_fn(fx_view_to_reshape, pm.fwd_only), pm_pass) + + +class AttnFusionPass(VllmInductorPass): + """ + This pass fuses post-attention quantization onto attention if supported. + + It uses the pattern matcher and matches each layer manually, as strings + cannot be wildcarded. This also lets us check support on attention layers + upon registration instead of during pattern matching. + + Currently, only static fp8 quant is supported, but patterns could easily be + added for other quant schemes and dtypes. The bigger hurdle for wider + support are attention kernels, which need to support fusing output quant. + """ + + def __init__(self, config: VllmConfig): + super().__init__(config) + self.static_fwd_ctx = config.compilation_config.static_forward_context + + self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass") + + for key, layer in self.static_fwd_ctx.items(): + pattern = AttentionStaticQuantPattern(key, layer.num_heads, + layer.head_size, + current_platform.fp8_dtype()) + pattern.register_if_supported(self.patterns, layer) + if len(self.static_fwd_ctx) == 0: + logger.warning( + "Attention + quant fusion is enabled, but " + "CompilationConfig.static_forward_context is empty. " + "Cannot access attention layers so no fusion " + "patterns were registered.") + + def __call__(self, graph: torch.fx.graph.Graph) -> None: + self.begin() + self.dump_graph(graph, "before_attn_fusion") + + count = self.patterns.apply(graph) + logger.debug("Fused quantization onto %s attention nodes", count) + self.dump_graph(graph, "after_attn_fusion") + self.end_and_log() diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py new file mode 100644 index 0000000..2db8b54 --- /dev/null +++ b/vllm/compilation/fx_utils.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import operator +from collections.abc import Iterable, Iterator +from typing import Optional + +from torch import fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._ops import OpOverload + + +def is_func(node: fx.Node, target) -> bool: + return node.op == "call_function" and node.target == target + + +def is_auto_func(node: fx.Node, op: OpOverload) -> bool: + return is_func(node, auto_functionalized) and node.args[0] == op + + +# Returns the first specified node with the given op (if it exists) +def find_specified_fn_maybe(nodes: Iterable[fx.Node], + op: OpOverload) -> Optional[fx.Node]: + for node in nodes: + if node.target == op: + return node + return None + + +# Returns the first specified node with the given op +def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node: + node = find_specified_fn_maybe(nodes, op) + assert node is not None, f"Could not find {op} in nodes {nodes}" + return node + + +# Returns the first auto_functionalized node with the given op (if it exists) +def find_auto_fn_maybe(nodes: Iterable[fx.Node], + op: OpOverload) -> Optional[fx.Node]: + for node in nodes: + if is_func(node, auto_functionalized) and node.args[0] == op: # noqa + return node + return None + + +# Returns the first auto_functionalized node with the given op +def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node: + node = find_auto_fn_maybe(nodes, op) + assert node is not None, f"Could not find {op} in nodes {nodes}" + return node + + +# Returns the getitem node that extracts the idx-th element from node +# (if it exists) +def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]: + for user in node.users: + if is_func(user, operator.getitem) and user.args[1] == idx: + return user + return None + + +# Returns the getitem node that extracts the idx-th element from node +def find_getitem(node: fx.Node, idx: int) -> fx.Node: + ret = find_getitem_maybe(node, idx) + assert ret is not None, f"Could not find getitem {idx} in node {node}" + return ret + + +# An auto-functionalization-aware utility for finding nodes with a specific op +def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]: + if not op._schema.is_mutable: + yield from graph.find_nodes(op="call_function", target=op) + + for n in graph.find_nodes(op="call_function", target=auto_functionalized): + if n.args[0] == op: + yield n + + +# Asserts that the node only has one user and returns it +# Even if a node has only 1 user, it might share storage with another node, +# which might need to be taken into account. +def get_only_user(node: fx.Node) -> fx.Node: + assert len(node.users) == 1 + return next(iter(node.users)) diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py new file mode 100644 index 0000000..810d080 --- /dev/null +++ b/vllm/compilation/inductor_pass.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +import inspect +import json +import types +from contextlib import contextmanager +from typing import Any, Callable, Optional, Union + +import torch +from torch import fx + +from vllm.utils import is_torch_equal_or_newer + +if is_torch_equal_or_newer("2.6"): + from torch._inductor.custom_graph_pass import CustomGraphPass +else: + # CustomGraphPass is not present in 2.5 or lower, import our version + from .torch25_custom_graph_pass import ( # noqa: E501 + Torch25CustomGraphPass as CustomGraphPass) + +_pass_context = None + + +class PassContext: + + def __init__(self, runtime_shape: Optional[int]): + self.runtime_shape = runtime_shape + + +def get_pass_context() -> PassContext: + """Get the current pass context.""" + assert _pass_context is not None + return _pass_context + + +@contextmanager +def pass_context(runtime_shape: Optional[int]): + """A context manager that stores the current pass context, + usually it is a list of sizes to specialize. + """ + global _pass_context + prev_context = _pass_context + _pass_context = PassContext(runtime_shape) + try: + yield + finally: + _pass_context = prev_context + + +class InductorPass(CustomGraphPass): + """ + A custom graph pass that uses a hash of its source as the UUID. + This is defined as a convenience and should work in most cases. + """ + + def uuid(self) -> Any: + """ + Provide a unique identifier for the pass, used in Inductor code cache. + This should depend on the pass implementation, so that changes to the + pass result in recompilation. + By default, the object source is hashed. + """ + return InductorPass.hash_source(self) + + @staticmethod + def hash_source(*srcs: Union[str, Any]): + """ + Utility method to hash the sources of functions or objects. + :param srcs: strings or objects to add to the hash. + Objects and functions have their source inspected. + :return: + """ + hasher = hashlib.sha256() + for src in srcs: + if isinstance(src, str): + src_str = src + elif isinstance(src, types.FunctionType): + src_str = inspect.getsource(src) + else: + src_str = inspect.getsource(src.__class__) + hasher.update(src_str.encode("utf-8")) + return hasher.hexdigest() + + @staticmethod + def hash_dict(dict_: dict[Any, Any]): + """ + Utility method to hash a dictionary, can alternatively be used for uuid. + :return: A sha256 hash of the json rep of the dictionary. + """ + encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") + return hashlib.sha256(encoded).hexdigest() + + def is_applicable_for_shape(self, shape: Optional[int]): + return True + + +class CallableInductorPass(InductorPass): + """ + This class is a wrapper for a callable that automatically provides an + implementation of the UUID. + """ + + def __init__(self, + callable: Callable[[fx.Graph], None], + uuid: Optional[Any] = None): + self.callable = callable + self._uuid = self.hash_source(callable) if uuid is None else uuid + + def __call__(self, graph: torch.fx.Graph): + self.callable(graph) + + def uuid(self) -> Any: + return self._uuid diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py new file mode 100644 index 0000000..1e059b5 --- /dev/null +++ b/vllm/compilation/monitor.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import time + +from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.logger import init_logger + +logger = init_logger(__name__) + +context_manager = None +torch_compile_start_time: float = 0.0 + + +def start_monitoring_torch_compile(vllm_config: VllmConfig): + global torch_compile_start_time + torch_compile_start_time = time.time() + + compilation_config: CompilationConfig = vllm_config.compilation_config + if compilation_config.level == CompilationLevel.PIECEWISE and \ + compilation_config.debug_dump_path: + import depyf + path = os.path.join(compilation_config.debug_dump_path, + f"rank_{vllm_config.parallel_config.rank}") + global context_manager + context_manager = depyf.prepare_debug(path) + context_manager.__enter__() + + +def end_monitoring_torch_compile(vllm_config: VllmConfig): + compilation_config: CompilationConfig = vllm_config.compilation_config + if compilation_config.level == CompilationLevel.PIECEWISE: + logger.info("torch.compile takes %.2f s in total", + compilation_config.compilation_time) + global context_manager + if context_manager is not None: + context_manager.__exit__(None, None, None) + context_manager = None diff --git a/vllm/compilation/multi_output_match.py b/vllm/compilation/multi_output_match.py new file mode 100644 index 0000000..6d18937 --- /dev/null +++ b/vllm/compilation/multi_output_match.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import abc +import operator +from abc import abstractmethod +from collections.abc import Iterable + +from torch import fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor import pattern_matcher as pm +from torch._ops import OpOverload +from torch.fx import Node + +from vllm.compilation.fx_utils import find_auto_fn + + +class MultiOutputMatch(abc.ABC): + """ + This class provides utilities to process multi-output matches and + manually insert replacements. + + This is necessary because the automatic replacement for multi-output + matches is broken: https://github.com/pytorch/pytorch/issues/137280 + """ + + def __init__(self, match: pm.Match): + self.match = match + + @abstractmethod + def process(self): + """ + Process a multi-output match and manually insert the replacement. + + This method should: + 1. Insert the replacement nodes after the last node in the match. + 2. Rebind the users of nodes in the match to use the new nodes. + 3. Set meta["val"] for de-functionalization. + + The result of an auto-functionalized node is a tuple of tensors. + The first element is the return value of the function, usually None. + The remaining elements are the mutated args of the function. + + All auto-functionalized nodes must contain a proper meta["val"], + as it is used by de-functionalization. meta["val"] has to contain the + value of the node (tuple of tensors) that would be returned by the + functionalized node during tracing. + + Existing nodes in the graph all have this property set, but we have + to set it manually for new nodes we insert. + + Example: + # op schema: foo(a: Tensor!, b: Tensor, c: Tensor!) -> None + at = auto_functionalized(torch.ops._C.foo.default, a, b, c) + # at.meta["val"] = (None, a, c) + """ + raise NotImplementedError + + @property + def nodes(self) -> list[fx.Node]: + return self.match.nodes + + @property + def graph(self) -> fx.Graph: + return self.match.graph + + def find_auto_fn(self, op) -> fx.Node: + """ + Find the first auto_functionalized node with the given op in the match. + """ + return find_auto_fn(self.nodes, op) + + def inserting_after_match(self): + """ + Insert nodes after the last node in the match. + This is done to avoid use-before-definition errors after inserting + replacement nodes. + """ + + # match.nodes is not guaranteed to be sorted. + # Find the last node in the match. + for last_node_in_match in reversed(self.graph.nodes): + if last_node_in_match in self.match.nodes: + break + else: + raise ValueError("No nodes in graph") + + return self.graph.inserting_after(last_node_in_match) + + def insert_getitems(self, tuple_node: fx.Node, + indices: Iterable[int]) -> tuple[fx.Node, ...]: + """ + Insert operator.getitem nodes to extract elements from a tuple node. + + :param tuple_node: The tuple node to extract elements from. + :param indices: The indices of the elements to extract. + :return: Tuple of the new getitem nodes, corresponding to the indices. + """ + with self.graph.inserting_after(tuple_node): + return tuple( + self.graph.call_function(operator.getitem, (tuple_node, idx)) + for idx in indices) + + def insert_auto_fn(self, op: OpOverload, kwargs) -> Node: + """ + Insert an auto_functionalized node with the given op and kwargs. + """ + return self.graph.call_function(auto_functionalized, (op, ), + kwargs=kwargs) diff --git a/vllm/compilation/noop_elimination.py b/vllm/compilation/noop_elimination.py new file mode 100644 index 0000000..4888d4d --- /dev/null +++ b/vllm/compilation/noop_elimination.py @@ -0,0 +1,165 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Union + +import torch.fx +from torch import SymInt + +from vllm.logger import init_logger + +from .fx_utils import is_func +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +class NoOpEliminationPass(VllmInductorPass): + """ + This is an inductor pass that removes redundant reshape/slice operations. + It is required for RMSNorm-quant fusion to work properly. + That's because apply_fp8_linear adds a reshape, which is redundant + in the 2D-case. Additionally, torch internal no-op elimination pass does + not handle certain slice variants. + + Cases handled: + 1. A chain of reshapes is equivalent to the last reshape called on the + base tensor (input of the first reshape). + 2. A reshape that produces the shape of the input is redundant + 3. A slice that produces the shape of the input is redundant + + Example graph 1: + mul_1: "f16[s0, 4096]" = ... + view_1: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32]) + view_2: "f16[s0, 4096]" = torch.reshape(view_2, [-1, 4096]) + view_3: "f16[s0, 128, 32]" = torch.reshape(view_3, [-1, 128, 32]) + + Can be replaced with: + mul_1: "f16[s0, 4096]" = ... + view_3: "f16[s0, 128, 32]" = ... + + Example graph 2: + getitem_1: "f16[s0, 4096]" = ... + view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096]) + at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...) + out: "f8e4m3fn[s0, 4096]" = at[1] + + Can be replaced with: + getitem_1: "f16[s0, 4096]" = ... + at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...) + out: "f8e4m3fn[s0, 4096]" = at[1] + + Example graph 3: + arg0: "s0" = SymInt(s0) + scaled_mm: "f16[s0, 4096]" = ... + slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0) + at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...) + out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0) + + Can be replaced with: + arg0: "s0" = SymInt(s0) + scaled_mm: "f16[s0, 4096]" = ... + at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...) + out: "f16[s0, 4096]" = at[1] + + TODO(luka): This is currently tested in test_fusion, + but separate tests could be good. + """ + + def __call__(self, graph: torch.fx.Graph): + self.begin() + self.dump_graph(graph, "before_noop_elimination") + count = 0 + # Remove no-op reshapes/views: + for node in graph.nodes: + if is_func(node, torch.ops.aten.reshape.default): + # Case 1: rewrite reshape chains to reshapes on the base tensor + input = node.args[0] + # If the input is a reshape, rebind to that node + if is_func(input, torch.ops.aten.reshape.default): + # The new input is guaranteed not to be a reshape, + # because we process nodes in order + node.update_arg(0, input.args[0]) + if len(input.users) == 0: + graph.erase_node(input) + count += 1 + + # Case 2: remove this reshape if it produces the original shape + input, shape = node.args[:2] + input_shape = input.meta["val"].shape + if len(shape) != len(input_shape): + # Reshape changing rank, skip + continue + + if shape.count(-1) > 1: + # Invalid reshape args, skip + continue + + if self.all_dims_equivalent(shape, input_shape): + node.replace_all_uses_with(input) + graph.erase_node(node) + count += 1 + + elif is_func(node, torch.ops.aten.slice.Tensor): + input, dim_index, start, end = node.args[:4] + input_shape = input.meta["val"].shape + i_dim = input_shape[dim_index] + + if start == 0 and self.dims_equivalent(end, i_dim): + node.replace_all_uses_with(input) + graph.erase_node(node) + count += 1 + + elif is_func(node, torch.ops.aten.slice_scatter.default): + base, view, dim_index, start, end = node.args[:5] + base_shape = base.meta["val"].shape + view_shape = view.meta["val"].shape + + view_dim = view_shape[dim_index] + + # Check that view fully covers base and the full view is used + # (if the view fully covered the base after slicing but was not + # fully used, we could replace slice_scatter with a simple slice + # but that's a niche case). + if (base_shape == view_shape and start == 0 + and self.dims_equivalent(end, view_dim)): + node.replace_all_uses_with(view) + graph.erase_node(node) + count += 1 + + logger.debug("Removed %s no-op reshapes and slices", count) + self.dump_graph(graph, "after_noop_elimination") + self.end_and_log() + + def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]], + i_dims: Iterable[Union[int, SymInt]]): + return all( + self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims)) + + def dims_equivalent(self, dim: Union[int, torch.fx.Node], + i_dim: Union[int, SymInt]) -> bool: + """ + This function checks if two dimensions are equivalent. + :param dim: The dimension arg to reshape/slice + :param i_dim: The corresponding dimension in the input tensor + :return: Are the dimensions equivalent? + + There are three cases in which the dimensions are equivalent: + 1. The dimensions are equal (both integers) + 2. The reshape dimension is -1 (i.e. inferred) + 3. The dimensions both correspond to the same SymInt + + While case 2 does not guarantee the dimensions are equal, + they are equal if all other dimensions are equal. + + In case 3, the reshape dimension is a torch.fx.Node, + and its value is a SymInt. That value is equal to the + input dimension. + + """ + # Case 1 and 2 + if dim == i_dim or dim == -1: + return True + # Case 3 + return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py new file mode 100644 index 0000000..420b86b --- /dev/null +++ b/vllm/compilation/pass_manager.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from torch import fx as fx + +from vllm.config import VllmConfig +from vllm.logger import init_logger + +# from .activation_quant_fusion import ActivationQuantFusionPass +from .collective_fusion import AsyncTPPass +from .fix_functionalization import FixFunctionalizationPass +from .fusion import FusionPass +from .fusion_attn import AttnFusionPass +from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context +from .noop_elimination import NoOpEliminationPass +from .sequence_parallelism import SequenceParallelismPass +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +class PostGradPassManager(CustomGraphPass): + """ + The pass manager for post-grad passes. + It handles configuration, adding custom passes, and running passes. + It supports uuid for the Inductor code cache. That includes torch<2.6 + support using pickling (in .inductor_pass.CustomGraphPass). + + The order of the post-grad post-passes is: + 1. passes (constructor parameter) + 2. default passes (NoopEliminationPass, FusionPass) + 3. config["post_grad_custom_post_pass"] (if it exists) + 4. fix_functionalization + This way, all passes operate on a functionalized graph. + """ + + def __init__(self): + self.passes: list[VllmInductorPass] = [] + + def __call__(self, graph: fx.Graph): + shape = get_pass_context().runtime_shape + for pass_ in self.passes: + if pass_.is_applicable_for_shape(shape): + pass_(graph) + + # always run fix_functionalization last + self.fix_functionalization(graph) + + def configure(self, config: VllmConfig): + self.pass_config = config.compilation_config.pass_config + if self.pass_config.enable_noop: + self.passes += [NoOpEliminationPass(config)] + + if self.pass_config.enable_sequence_parallelism: + self.passes += [SequenceParallelismPass(config)] + if self.pass_config.enable_async_tp: + self.passes += [AsyncTPPass(config)] + + # if self.pass_config.enable_fusion: + # self.passes += [FusionPass.instance(config)] + # self.passes += [ActivationQuantFusionPass(config)] + + if self.pass_config.enable_attn_fusion: + self.passes += [AttnFusionPass(config)] + + self.fix_functionalization = FixFunctionalizationPass(config) + + def add(self, pass_: InductorPass): + assert isinstance(pass_, InductorPass) + self.passes.append(pass_) + + def uuid(self): + """ + The PostGradPassManager is set as a custom pass in the Inductor and + affects compilation caching. Its uuid depends on the UUIDs of all + dependent passes and the pass config. See InductorPass for more info. + """ + state = {"pass_config": self.pass_config.uuid(), "passes": []} + for pass_ in self.passes: + state["passes"].append(pass_.uuid()) + state["passes"].append(self.fix_functionalization.uuid()) + return InductorPass.hash_dict(state) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py new file mode 100644 index 0000000..29305a8 --- /dev/null +++ b/vllm/compilation/sequence_parallelism.py @@ -0,0 +1,482 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch +import torch._inductor.pattern_matcher as pm +import torch.fx as fx +from torch._inductor.pattern_matcher import PatternMatcherPass + +from vllm.config import VllmConfig +from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +class _RMSNormAndQuantOpHelper: + """Base helper for RMSNorm and RMSNorm + Quantization functionalization.""" + + def __init__(self, + epsilon: float, + dtype: torch.dtype, + device: str, + quant_op: Optional[torch._ops.OpOverload] = None, + **kwargs): + self.epsilon = epsilon + self.dtype = dtype + self.device = device + self.quant_op = quant_op + + def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor): + return torch.ops.higher_order.auto_functionalized( + torch.ops._C.rms_norm.default, + result=result_buffer, + input=input_tensor, + weight=weight_tensor, + epsilon=self.epsilon) + + def _functional_fused_add_rmsnorm(self, input_tensor, residual_tensor, + weight_tensor): + return torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=input_tensor, + residual=residual_tensor, + weight=weight_tensor, + epsilon=self.epsilon) + + def _functional_rmsnorm_then_quant(self, rmsnorm_result_buffer, + quant_result_buffer, input_tensor, + weight_tensor, scale_tensor): + if self.quant_op is None: + raise RuntimeError( + "_RMSNormAndQuantOpHelper was not initialized with a quant_op." + ) + rmsnorm_out_tuple = self._functional_rmsnorm(rmsnorm_result_buffer, + input_tensor, + weight_tensor) + quant_out_tuple = torch.ops.higher_order.auto_functionalized( + self.quant_op, + result=quant_result_buffer, + input=rmsnorm_out_tuple[1], + scale=scale_tensor) + return quant_out_tuple + + def _functional_fused_add_rmsnorm_then_quant(self, quant_result_buffer, + input_tensor, residual_tensor, + weight_tensor, scale_tensor): + if self.quant_op is None: + raise RuntimeError( + "_RMSNormAndQuantOpHelper was not initialized with a quant_op." + ) + fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm( + input_tensor, residual_tensor, weight_tensor) + quant_out_tuple = torch.ops.higher_order.auto_functionalized( + self.quant_op, + result=quant_result_buffer, + input=fused_add_rmsnorm_out_tuple[1], + scale=scale_tensor) + return quant_out_tuple, fused_add_rmsnorm_out_tuple[2] + + +class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper): + """Helper for sequence parallelism patterns.""" + + def __init__(self, + epsilon: float, + dtype: torch.dtype, + device: str, + quant_op: Optional[torch._ops.OpOverload] = None, + **kwargs): + super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs) + self.tp_group = get_tp_group() + self.tp_size = get_tensor_model_parallel_world_size() + + def _all_reduce(self, x: torch.Tensor) -> torch.Tensor: + return tensor_model_parallel_all_reduce(x) + + def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.vllm.reduce_scatter.default( + x, + dim=0, + world_size=self.tp_size, + group_name=self.tp_group.unique_name) + + def _all_gather(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.vllm.all_gather.default( + x, + dim=0, + world_size=self.tp_size, + group_name=self.tp_group.unique_name) + + +class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): + + def get_inputs(self): + input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) + permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) + arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype) + + return [input, permute, arg3_1] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + input: torch.Tensor, + permute: torch.Tensor, + arg3_1: torch.Tensor, + ): + all_reduce = self._all_reduce(input) + rmsnorm = self._functional_rmsnorm(permute, all_reduce, arg3_1) + + return rmsnorm[1], all_reduce + + def replacement( + input: torch.Tensor, + permute: torch.Tensor, + arg3_1: torch.Tensor, + ): + reduce_scatter = self._reduce_scatter(input) + + rmsnorm_result = torch.empty_like(reduce_scatter) + rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, + arg3_1) + + all_gather = self._all_gather(rmsnorm[1]) + + return all_gather, reduce_scatter + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + + return [ + residual, + mm_1, + rms_norm_weights, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + all_reduce = self._all_reduce(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm( + all_reduce, residual, rms_norm_weights) + return rmsnorm[1], rmsnorm[2] + + def replacement( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + reduce_scatter = self._reduce_scatter(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm( + reduce_scatter, residual, rms_norm_weights) + all_gather = self._all_gather(rmsnorm[1]) + return all_gather, rmsnorm[2] + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper): + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + + return [ + residual, + mm_1, + rms_norm_weights, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + all_reduce = self._all_reduce(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm( + all_reduce, residual, rms_norm_weights) + return rmsnorm[1] + + def replacement( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + reduce_scatter = self._reduce_scatter(mm_1) + rmsnorm = self._functional_fused_add_rmsnorm( + reduce_scatter, residual, rms_norm_weights) + normalized = self._all_gather(rmsnorm[1]) + return normalized + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +FP8_DTYPE = current_platform.fp8_dtype() + + +class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + op: torch._ops.OpOverload): + super().__init__(epsilon, dtype, device, quant_op=op) + + def get_inputs(self): + input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) + rmsnorm_result = torch.empty([1, 8, 4], + device=self.device, + dtype=self.dtype) + quant_result = torch.empty([1, 8, 4], + device=self.device, + dtype=FP8_DTYPE) + weight = torch.empty([4], device=self.device, dtype=self.dtype) + scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) + return [input, rmsnorm_result, quant_result, weight, scale] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + input: torch.Tensor, + rmsnorm_result: torch.Tensor, + quant_result: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + all_reduce = self._all_reduce(input) + static_fp8 = self._functional_rmsnorm_then_quant( + rmsnorm_result, quant_result, all_reduce, weight, scale) + return static_fp8[1], all_reduce + + def replacement( + input: torch.Tensor, + rmsnorm_result: torch.Tensor, + quant_result: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + reduce_scatter = self._reduce_scatter(input) + + rmsnorm_result = torch.empty_like(reduce_scatter, + dtype=rmsnorm_result.dtype) + quant_result = torch.empty_like( + rmsnorm_result, # Output of RMSNorm + dtype=quant_result.dtype) + static_fp8 = self._functional_rmsnorm_then_quant( + rmsnorm_result, quant_result, reduce_scatter, weight, scale) + all_gather = self._all_gather(static_fp8[1]) + + return all_gather, reduce_scatter + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + op: torch._ops.OpOverload): + super().__init__(epsilon, dtype, device, quant_op=op) + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) + scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) + + return [ + result, + residual, + mm_1, + rms_norm_weights, + scale, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + all_reduce = self._all_reduce(mm_1) + static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 + result, all_reduce, residual, rms_norm_weights, scale) + return static_fp8[1], rmsnorm_residual_out + + def replacement( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + reduce_scatter = self._reduce_scatter(mm_1) + quant_result_buf = torch.empty_like(reduce_scatter, + dtype=result.dtype) + static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 + quant_result_buf, reduce_scatter, residual, rms_norm_weights, + scale) + all_gather = self._all_gather(static_fp8[1]) + return all_gather, rmsnorm_residual_out + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): + + def __init__(self, epsilon: float, dtype: torch.dtype, device: str, + op: torch._ops.OpOverload): + super().__init__(epsilon, dtype, device, quant_op=op) + + def get_inputs(self): + mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) + + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], + device=self.device, + dtype=self.dtype) + result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) + scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) + + return [ + result, + residual, + mm_1, + rms_norm_weights, + scale, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + all_reduce = self._all_reduce(mm_1) + static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant( + result, all_reduce, residual, rms_norm_weights, scale) + return static_fp8[1] + + def replacement( + result: torch.Tensor, + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + reduce_scatter = self._reduce_scatter(mm_1) + quant_result_buf = torch.empty_like(reduce_scatter, + dtype=result.dtype) + static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant( + quant_result_buf, reduce_scatter, residual, rms_norm_weights, + scale) + normalized = self._all_gather(static_fp8[1]) + return normalized + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class SequenceParallelismPass(VllmInductorPass): + """ + This pass enables sequence parallelism for models. + It identifies patterns where an AllReduce operation is followed by + an RMSNorm (or RMSNorm and then Quantization) operation. + These patterns are replaced with a ReduceScatter operation, followed by + a local RMSNorm/Quantization, and then an AllGather operation. + + The general transformation is: + Input -> AllReduce -> RMSNorm -> Output + becomes + Input -> ReduceScatter -> RMSNorm -> AllGather -> Output + + While this pass itself does not directly yield performance improvements, + it lays the groundwork for subsequent fusion passes, such as + GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can + significantly reduce communication overhead and improve overall model + performance. + """ + + def __init__(self, config: VllmConfig): + super().__init__(config) + + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="sequence_parallelism_pass") + + for epsilon in [1e-5, 1e-6]: + # RMSNorm + Static FP8 quantization patterns + # fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default + # FirstAllReduceRMSNormStaticFP8Pattern( + # epsilon, self.model_dtype, self.device, + # fp8_quant_op).register(self.patterns) + # MiddleAllReduceRMSNormStaticFP8Pattern( + # epsilon, self.model_dtype, self.device, + # fp8_quant_op).register(self.patterns) + # LastAllReduceRMSNormStaticFP8Pattern( + # epsilon, self.model_dtype, self.device, + # fp8_quant_op).register(self.patterns) + + # Normal RMSNorm patterns + FirstAllReduceRMSNormPattern(epsilon, self.model_dtype, + self.device).register(self.patterns) + + MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype, + self.device).register(self.patterns) + + LastAllReduceRMSNormPattern(epsilon, self.model_dtype, + self.device).register(self.patterns) + + # WARNING: This is a hack to clear the pattern matcher cache + # and allow multiple values of epsilon. + torch._inductor.pattern_matcher._seen_patterns.clear() + + def is_applicable_for_shape(self, shape: Optional[int]) -> bool: + tp_size = get_tensor_model_parallel_world_size() + return shape is not None and shape % tp_size == 0 + + def __call__(self, graph: fx.Graph): + self.begin() + self.dump_graph(graph, "before_sequence_parallelism_pass") + count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", count) + self.dump_graph(graph, "after_sequence_parallelism_pass") + self.end_and_log() diff --git a/vllm/compilation/torch25_custom_graph_pass.py b/vllm/compilation/torch25_custom_graph_pass.py new file mode 100644 index 0000000..cd39706 --- /dev/null +++ b/vllm/compilation/torch25_custom_graph_pass.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from typing import Any, Optional + +import torch + + +class Torch25CustomGraphPass(ABC): # noqa (redefinition) + """ + This class replaces CustomGraphPass from torch==2.6 when using torch<2.6. + It conforms to the 2.6 interface but also supports pickling, as that's what + the inductor code cache uses to determine the cache key before 2.6. + (in 2.6 and above, uuid() is used.) + + Subclasses can just "pretend" that uuid is used. + """ + + @abstractmethod + def __call__(self, graph: torch.fx.graph.Graph) -> None: + """ + Implementation of the custom pass. + """ + + @abstractmethod + def uuid(self) -> Optional[Any]: + """ + Return an ID to uniquely identify your custom pass implementation. + Return None to skip inductor code caching entirely. + """ + + def __getstate__(self): + """ + Pickling is used instead of uuid() in torch<2.6. Just return uuid() + to enable subclasses to only have to implement uuid. + """ + return self.uuid() + + def __setstate__(self, state): + raise ValueError("Cannot unpickle CustomGraphPass because pickling" + " is used for cache key uuid. Use torch>=2.6 with" + " native uuid support for custom passes.") diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py new file mode 100644 index 0000000..628e9e2 --- /dev/null +++ b/vllm/compilation/vllm_inductor_pass.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import time + +import torch +from torch._dynamo.utils import lazy_format_graph_code + +from vllm.config import PassConfig, VllmConfig +# yapf: disable +from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank +from vllm.distributed import ( + get_tensor_model_parallel_world_size as get_tp_world_size) +from vllm.distributed import model_parallel_is_initialized as p_is_init +# yapf: enable +from vllm.logger import init_logger + +from .inductor_pass import InductorPass + +logger = init_logger(__name__) + + +class VllmInductorPass(InductorPass): + """ + An inductor pass with access to vLLM PassConfig. + It provides timing, logging, and dumping utilities. + """ + + def __init__(self, config: VllmConfig): + self.pass_config = config.compilation_config.pass_config + self.model_dtype = config.model_config.dtype if config.model_config \ + else None + self.device = config.device_config.device if config.device_config \ + else None + self.pass_name = self.__class__.__name__ + + def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False): + lazy_format_graph_code(stage, graph.owning_module) + + if stage in self.pass_config.dump_graph_stages or always: + # Make sure filename includes rank in the distributed setting + parallel = p_is_init() and get_tp_world_size() > 1 + rank = f"-{get_tp_rank()}" if parallel else "" + filepath = self.pass_config.dump_graph_dir / f"{stage}{rank}.py" + + logger.info("%s printing graph to %s", self.pass_name, filepath) + with open(filepath, "w") as f: + src = graph.python_code(root_module="self", verbose=True).src + # Add imports so it's not full of errors + print("import torch; from torch import device", file=f) + print(src, file=f) + + def begin(self): + self._start_time = time.perf_counter_ns() + + def end_and_log(self): + self._end_time = time.perf_counter_ns() + duration_ms = float(self._end_time - self._start_time) / 1.0e6 + logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms) + + +class PrinterInductorPass(VllmInductorPass): + + def __init__(self, name: str, config: PassConfig, always=False): + super().__init__(config) + self.name = name + self.always = always + + def __call__(self, graph: torch.fx.Graph): + self.dump_graph(graph, self.name, always=self.always) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py new file mode 100644 index 0000000..2a261c8 --- /dev/null +++ b/vllm/compilation/wrapper.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import sys +from abc import abstractmethod +from contextlib import contextmanager +from types import CodeType +from typing import Callable, Optional + +import torch + +import vllm.envs as envs +from vllm.config import CompilationLevel, get_current_vllm_config +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class TorchCompileWrapperWithCustomDispatcher: + """ + A wrapper class for torch.compile, with a custom dispatch logic. + Subclasses should: + 1. Implement the forward method + 2. Implement the dispatch logic in the __call__ method + It can use `self.compiled_codes` to access the compiled bytecode, + and `with self.dispatch_to_code(index):` to dispatch to + the compiled code. + 3. Implement the `__init__` method to determine how to call + `torch.compile` over the forward method. + """ + + def __init__(self, + compiled_callable: Optional[Callable] = None, + compilation_level: int = 0): + + vllm_config = get_current_vllm_config() + self.vllm_config = vllm_config + if compiled_callable is None: + # default compilation settings + # compiling the forward method + + backend = vllm_config.compilation_config.init_backend(vllm_config) + options = None + if isinstance(backend, str) and backend == "inductor": + options = get_current_vllm_config( + ).compilation_config.inductor_compile_config + + compiled_callable = torch.compile( + self.forward, + fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + backend=backend, + options=options) + + self.compiled_callable = compiled_callable + self.original_code_object = self.__class__.forward.__code__ + self.compiled_codes: list[CodeType] = [] + torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook) + + # read the env var to determine whether to use the custom dispatcher + # subclasses can use this to switch between the custom dispatcher + # and the default Dynamo guard mechanism. + self.use_custom_dispatcher: bool = \ + compilation_level >= CompilationLevel.DYNAMO_ONCE + + def __call__(self, *args, **kwargs): + """Implement the dispatch logic here, beyond the torch.compile level. + NOTE: this function can have additional arguments beyond the forward + method, for directly dispatching to the compiled code. + """ + return self.compiled_callable(*args, **kwargs) + + @abstractmethod + def forward(self, *args, **kwargs): + ... + + def bytecode_hook(self, old_code: CodeType, new_code: CodeType): + """Hook to save the compiled bytecode for direct execution.""" + if old_code is not self.original_code_object: + return + # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25 + frame = sys._getframe() + while frame and frame.f_back: + frame = frame.f_back + code_name = frame.f_code.co_name + file_name = frame.f_code.co_filename.split(os.path.sep)[-1] + if code_name == "_compile" and file_name == "convert_frame.py": + break + frame = frame.f_locals["frame"] + assert frame.f_code == old_code + + if frame.f_locals["self"] is not self: + return + + self.compiled_codes.append(new_code) + local_cache_dir = self.vllm_config.compilation_config.local_cache_dir + if isinstance(local_cache_dir, str): + decompiled_file = os.path.join(local_cache_dir, + "transformed_code.py") + if not os.path.exists(decompiled_file): + try: + # usually the decompilation will succeed for most models, + # as we guarantee a full-graph compilation in Dynamo. + # but there's no 100% guarantee, since decompliation is + # not a reversible process. + import depyf + src = depyf.decompile(new_code) + with open(decompiled_file, "w") as f: + f.write(src) + + logger.debug("Dynamo transformed code saved to %s", + decompiled_file) + except Exception: + pass + + if self.vllm_config.compilation_config.use_cudagraph and \ + "update" in new_code.co_names: + import depyf + src = depyf.decompile(new_code) + msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa + raise RuntimeError(msg) + + @contextmanager + def dispatch_to_code(self, index: int): + """Context manager to dispatch to the compiled code. + Why does this work? Because Dynamo guarantees that the compiled + bytecode has exactly the same arguments, cell variables, and free + variables as the original code. Therefore we can directly switch + the code object in the function and call it. + + See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details. + """ # noqa + self.__class__.forward.__code__ = self.compiled_codes[index] + yield + self.__class__.forward.__code__ = self.original_code_object diff --git a/vllm/config.py b/vllm/config.py new file mode 100644 index 0000000..63c9ca7 --- /dev/null +++ b/vllm/config.py @@ -0,0 +1,4953 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import ast +import copy +import enum +import hashlib +import inspect +import json +import textwrap +import uuid +import warnings +from collections import Counter +from contextlib import contextmanager +from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass, + replace) +from functools import cached_property +from importlib.util import find_spec +from pathlib import Path +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, List, + Protocol, TypeVar, Union, cast, get_args, get_origin) + +import regex as re +import torch +from pydantic import (ConfigDict, SkipValidation, TypeAdapter, field_validator, + model_validator) +from pydantic.dataclasses import dataclass +from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE +from torch.distributed import ProcessGroup, ReduceOp +from typing_extensions import Self, deprecated, runtime_checkable + +import vllm.envs as envs +from vllm import version +from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.transformers_utils.config import ( + ConfigFormat, get_config, get_hf_image_processor_config, + get_hf_text_config, get_pooling_config, + get_sentence_transformer_tokenizer_config, is_encoder_decoder, + try_get_generation_config, try_get_safetensors_metadata, + try_get_tokenizer_config, uses_mrope) +from vllm.transformers_utils.s3_utils import S3Model +from vllm.transformers_utils.utils import is_s3, maybe_model_redirect + +# yapf conflicts with isort for this block +# yapf: disable +from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, + MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, + POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes, + LayerBlockType, LazyLoader, common_broadcastable_dtype, + cuda_device_count_stateless, get_cpu_memory, + get_open_port, is_torch_equal_or_newer, random_uuid, + resolve_obj_by_qualname) +from vllm.utils import SUPPORT_TC + +# yapf: enable + +if TYPE_CHECKING: + from _typeshed import DataclassInstance + from ray.util.placement_group import PlacementGroup + from transformers.configuration_utils import PretrainedConfig + + import vllm.model_executor.layers.quantization as me_quant + import vllm.model_executor.models as me_models + from vllm.executor.executor_base import ExecutorBase + from vllm.model_executor.layers.quantization import QuantizationMethods + from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + from vllm.model_executor.model_loader import BaseModelLoader + from vllm.model_executor.model_loader.tensorizer import TensorizerConfig + + ConfigType = type[DataclassInstance] + HfOverrides = Union[dict, Callable[[type], type]] +else: + PlacementGroup = Any + PretrainedConfig = Any + ExecutorBase = Any + QuantizationConfig = Any + QuantizationMethods = Any + BaseModelLoader = Any + TensorizerConfig = Any + ConfigType = type + HfOverrides = Union[dict[str, Any], Callable[[type], type]] + + me_quant = LazyLoader("model_executor", globals(), + "vllm.model_executor.layers.quantization") + me_models = LazyLoader("model_executor", globals(), + "vllm.model_executor.models") + +logger = init_logger(__name__) + +models_path_prefix = os.getenv('VLLM_OPTEST_MODELS_PATH') or os.getenv("OPTEST_MODELS_PATH") + +ConfigT = TypeVar("ConfigT", bound=ConfigType) + +TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", + "score", "reward", "transcription"] + +_ResolvedTask = Literal["generate", "embed", "classify", "reward", "draft", + "transcription"] + +RunnerType = Literal["generate", "pooling", "draft", "transcription"] + +_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = { + "generate": ["generate"], + "pooling": ["embed", "classify", "reward"], + "draft": ["draft"], + "transcription": ["transcription"], +} + +_TASK_RUNNER: dict[_ResolvedTask, RunnerType] = { + task: runner + for runner, tasks in _RUNNER_TASKS.items() + for task in tasks +} + + +@runtime_checkable +class SupportsHash(Protocol): + + def compute_hash(self) -> str: + ... + + +class SupportsMetricsInfo(Protocol): + + def metrics_info(self) -> dict[str, str]: + ... + + +class ModelImpl(str, enum.Enum): + AUTO = "auto" + VLLM = "vllm" + TRANSFORMERS = "transformers" + + +def get_attr_docs(cls: type[Any]) -> dict[str, str]: + """ + Get any docstrings placed after attribute assignments in a class body. + + https://davidism.com/mit-license/ + """ + + def pairwise(iterable): + """ + Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise + + Can be removed when Python 3.9 support is dropped. + """ + iterator = iter(iterable) + a = next(iterator, None) + + for b in iterator: + yield a, b + a = b + + cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0] + + if not isinstance(cls_node, ast.ClassDef): + raise TypeError("Given object was not a class.") + + out = {} + + # Consider each pair of nodes. + for a, b in pairwise(cls_node.body): + # Must be an assignment then a constant string. + if (not isinstance(a, (ast.Assign, ast.AnnAssign)) + or not isinstance(b, ast.Expr) + or not isinstance(b.value, ast.Constant) + or not isinstance(b.value.value, str)): + continue + + doc = inspect.cleandoc(b.value.value) + + # An assignment can have multiple targets (a = b = v), but an + # annotated assignment only has one target. + targets = a.targets if isinstance(a, ast.Assign) else [a.target] + + for target in targets: + # Must be assigning to a plain name. + if not isinstance(target, ast.Name): + continue + + out[target.id] = doc + + return out + + +def config(cls: ConfigT) -> ConfigT: + """ + A decorator that ensures all fields in a dataclass have default values + and that each field has a docstring. + + If a `ConfigT` is used as a CLI argument itself, the default value provided + by `get_kwargs` will be the result parsing a JSON string as the kwargs + (i.e. `ConfigT(**json.loads(cli_arg))`). However, if a particular `ConfigT` + requires custom construction from CLI (i.e. `CompilationConfig`), it can + have a `from_cli` method, which will be called instead. + + Config validation is performed by the tools/validate_config.py + script, which is invoked during the pre-commit checks. + """ + return cls + + +def get_field(cls: ConfigType, name: str) -> Field: + """Get the default factory field of a dataclass by name. Used for getting + default factory fields in `EngineArgs`.""" + if not is_dataclass(cls): + raise TypeError("The given class is not a dataclass.") + cls_fields = {f.name: f for f in fields(cls)} + if name not in cls_fields: + raise ValueError(f"Field '{name}' not found in {cls.__name__}.") + named_field: Field = cls_fields[name] + if (default_factory := named_field.default_factory) is not MISSING: + return field(default_factory=default_factory) + if (default := named_field.default) is not MISSING: + return field(default=default) + raise ValueError( + f"{cls.__name__}.{name} must have a default value or default factory.") + + +def is_init_field(cls: ConfigType, name: str) -> bool: + return next(f for f in fields(cls) if f.name == name).init + + +TokenizerMode = Literal["auto", "cpm", "slow", "mistral", "custom"] +ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] + + +@config +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class ModelConfig: + """Configuration for the model.""" + + model: str = os.path.join(models_path_prefix, 'facebook/opt-125m') if models_path_prefix is not None else 'facebook/opt-125m' + """Name or path of the Hugging Face model to use. It is also used as the + content for `model_name` tag in metrics output when `served_model_name` is + not specified.""" + task: Literal[TaskOption, Literal["draft"]] = "auto" + """The task to use the model for. Each vLLM instance only supports one + task, even if the same model can be used for multiple tasks. When the model + only supports one task, "auto" can be used to select it; otherwise, you + must specify explicitly which task to use.""" + tokenizer: SkipValidation[str] = None # type: ignore + """Name or path of the Hugging Face tokenizer to use. If unspecified, model + name or path will be used.""" + tokenizer_mode: TokenizerMode = "auto" + """Tokenizer mode:\n + - "auto" will use the fast tokenizer if available.\n + - "slow" will always use the slow tokenizer.\n + - "mistral" will always use the tokenizer from `mistral_common`.\n + - "custom" will use --tokenizer to select the preregistered tokenizer.""" + trust_remote_code: bool = False + """Trust remote code (e.g., from HuggingFace) when downloading the model + and tokenizer.""" + dtype: Union[ModelDType, torch.dtype] = "auto" + """Data type for model weights and activations:\n + - "auto" will use FP16 precision for FP32 and FP16 models, and BF16 + precision for BF16 models.\n + - "half" for FP16. Recommended for AWQ quantization.\n + - "float16" is the same as "half".\n + - "bfloat16" for a balance between precision and range.\n + - "float" is shorthand for FP32 precision.\n + - "float32" for FP32 precision.""" + seed: Optional[int] = None + """Random seed for reproducibility. Initialized to None in V0, but + initialized to 0 in V1.""" + hf_config_path: Optional[str] = None + """Name or path of the Hugging Face config to use. If unspecified, model + name or path will be used.""" + allowed_local_media_path: str = "" + """Allowing API requests to read local images or videos from directories + specified by the server file system. This is a security risk. Should only + be enabled in trusted environments.""" + revision: Optional[str] = None + """The specific model version to use. It can be a branch name, a tag name, + or a commit id. If unspecified, will use the default version.""" + code_revision: Optional[str] = None + """The specific revision to use for the model code on the Hugging Face Hub. + It can be a branch name, a tag name, or a commit id. If unspecified, will + use the default version.""" + rope_scaling: dict[str, Any] = field(default_factory=dict) + """RoPE scaling configuration. For example, + `{"rope_type":"dynamic","factor":2.0}`.""" + rope_theta: Optional[float] = None + """RoPE theta. Use with `rope_scaling`. In some cases, changing the RoPE + theta improves the performance of the scaled model.""" + tokenizer_revision: Optional[str] = None + """The specific revision to use for the tokenizer on the Hugging Face Hub. + It can be a branch name, a tag name, or a commit id. If unspecified, will + use the default version.""" + max_model_len: SkipValidation[int] = None # type: ignore + """Model context length (prompt and output). If unspecified, will be + automatically derived from the model config. + + When passing via `--max-model-len`, supports k/m/g/K/M/G in human-readable + format. Examples:\n + - 1k -> 1000\n + - 1K -> 1024\n + - 25.6k -> 25,600""" + spec_target_max_model_len: Optional[int] = None + """Specify the maximum length for spec decoding draft models.""" + quantization: SkipValidation[Optional[QuantizationMethods]] = None + """Method used to quantize the weights. If `None`, we first check the + `quantization_config` attribute in the model config file. If that is + `None`, we assume the model weights are not quantized and use `dtype` to + determine the data type of the weights.""" + enforce_eager: bool = False + """Whether to always use eager-mode PyTorch. If True, we will disable CUDA + graph and always execute the model in eager mode. If False, we will use + CUDA graph and eager execution in hybrid for maximal performance and + flexibility.""" + max_seq_len_to_capture: Optional[int] = None # 8192 + """Maximum sequence len covered by CUDA graphs. When a sequence has context + length larger than this, we fall back to eager mode. Additionally for + encoder-decoder models, if the sequence length of the encoder input is + larger than this, we fall back to the eager mode.""" + max_logprobs: int = 20 + """Maximum number of log probabilities to return when `logprobs` is + specified in `SamplingParams`. The default value comes the default for the + OpenAI Chat Completions API.""" + disable_sliding_window: bool = False + """Whether to disable sliding window. If True, we will disable the sliding + window functionality of the model, capping to sliding window size. If the + model does not support sliding window, this argument is ignored.""" + disable_cascade_attn: bool = False + """Disable cascade attention for V1. While cascade attention does not + change the mathematical correctness, disabling it could be useful for + preventing potential numerical issues. Note that even if this is set to + False, cascade attention will be only used when the heuristic tells that + it's beneficial.""" + skip_tokenizer_init: bool = False + """Skip initialization of tokenizer and detokenizer. Expects valid + `prompt_token_ids` and `None` for prompt from the input. The generated + output will contain token ids.""" + enable_prompt_embeds: bool = False + """If `True`, enables passing text embeddings as inputs via the + `prompt_embeds` key. Note that enabling this will double the time required + for graph compilation.""" + served_model_name: Optional[Union[str, list[str]]] = None + """The model name(s) used in the API. If multiple names are provided, the + server will respond to any of the provided names. The model name in the + model field of a response will be the first name in this list. If not + specified, the model name will be the same as the `--model` argument. Noted + that this name(s) will also be used in `model_name` tag content of + prometheus metrics, if multiple names provided, metrics tag will take the + first one.""" + limit_mm_per_prompt: dict[str, int] = field(default_factory=dict) + """Maximum number of data items per modality per prompt. Only applicable + for multimodal models.""" + media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) + """Additional args passed to process media inputs, keyed by modalities. + For example, to set num_frames for video, set + `--media-io-kwargs '{"video": {"num_frames": 40} }'` """ + use_async_output_proc: bool = True + """Whether to use async output processor.""" + config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value + """The format of the model config to load:\n + - "auto" will try to load the config in hf format if available else it + will try to load in mistral format.\n + - "hf" will load the config in hf format.\n + - "mistral" will load the config in mistral format.""" + hf_token: Optional[Union[bool, str]] = None + """The token to use as HTTP bearer authorization for remote files . If + `True`, will use the token generated when running `huggingface-cli login` + (stored in `~/.huggingface`).""" + hf_overrides: HfOverrides = field(default_factory=dict) + """If a dictionary, contains arguments to be forwarded to the Hugging Face + config. If a callable, it is called to update the HuggingFace config.""" + mm_processor_kwargs: Optional[dict[str, Any]] = None + """Arguments to be forwarded to the model's processor for multi-modal data, + e.g., image processor. Overrides for the multi-modal processor obtained + from `AutoProcessor.from_pretrained`. The available overrides depend on the + model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`. + """ + disable_mm_preprocessor_cache: bool = False + """If `True`, disable caching of the multi-modal preprocessor/mapper (not + recommended).""" + override_neuron_config: dict[str, Any] = field(default_factory=dict) + """Initialize non-default neuron config or override default neuron config + that are specific to Neuron devices, this argument will be used to + configure the neuron config that can not be gathered from the vllm + arguments. e.g. `{"cast_logits_dtype": "bfloat16"}`.""" + pooler_config: Optional["PoolerConfig"] = field(init=False) + """Pooler config which controls the behaviour of output pooling in pooling + models.""" + override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None + """Initialize non-default pooling config or override default pooling config + for the pooling model. e.g. `{"pooling_type": "mean", "normalize": false}`. + """ + logits_processor_pattern: Optional[str] = None + """Optional regex pattern specifying valid logits processor qualified names + that can be passed with the `logits_processors` extra completion argument. + Defaults to `None`, which allows no processors.""" + generation_config: str = "auto" + """The folder path to the generation config. Defaults to `"auto"`, the + generation config will be loaded from model path. If set to `"vllm"`, no + generation config is loaded, vLLM defaults will be used. If set to a folder + path, the generation config will be loaded from the specified folder path. + If `max_new_tokens` is specified in generation config, then it sets a + server-wide limit on the number of output tokens for all requests.""" + override_generation_config: dict[str, Any] = field(default_factory=dict) + """Overrides or sets generation config. e.g. `{"temperature": 0.5}`. If + used with `--generation-config auto`, the override parameters will be + merged with the default config from the model. If used with + `--generation-config vllm`, only the override parameters are used.""" + enable_sleep_mode: bool = False + """Enable sleep mode for the engine (only cuda platform is supported).""" + model_impl: Union[str, ModelImpl] = ModelImpl.AUTO.value + """Which implementation of the model to use:\n + - "auto" will try to use the vLLM implementation, if it exists, and fall + back to the Transformers implementation if no vLLM implementation is + available.\n + - "vllm" will use the vLLM model implementation.\n + - "transformers" will use the Transformers model implementation.""" + override_attention_dtype: Optional[str] = None + """Override dtype for attention""" + + enable_chunked_prefill: Optional[bool] = None + """If True, prefill requests can be chunked based + on the remaining max_num_batched_tokens.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + factors.append(self.model) + factors.append(self.dtype) + factors.append(self.quantization) + factors.append(self.revision) + factors.append(self.code_revision) + factors.append(self.max_model_len) + factors.append(self.max_logprobs) + factors.append(self.disable_sliding_window) + factors.append(self.trust_remote_code) + factors.append(self.generation_config) + factors.append(self.model_impl) + factors.append(self.override_generation_config) + factors.append(self.rope_scaling) + factors.append(self.rope_theta) + # hf_config can control how the model looks! + factors.append(self.hf_config.to_json_string()) + factors.append(self.enable_chunked_prefill) + str_factors = str(factors) + assert_hashable(str_factors) + return hashlib.sha256(str(factors).encode()).hexdigest() + + def __post_init__(self) -> None: + # Set the default seed to 0 in V1. + # NOTE(woosuk): In V0, we set the default seed to None because the + # driver worker shares the same process as the user process, and thus + # setting a seed affects the user process as well. + # In V1, we use separate processes for workers (unless + # VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here + # doesn't affect the user process. However, without a consistent seed, + # different tensor parallel workers would sample different tokens, + # leading to inconsistent results. + if envs.VLLM_USE_V1 and self.seed is None: + self.seed = 0 + if not envs.VLLM_ENABLE_V1_MULTIPROCESSING: + logger.warning( + "The global random seed is set to %d. Since " + "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may " + "affect the random state of the Python process that " + "launched vLLM.", self.seed) + + # Keep set served_model_name before maybe_model_redirect(self.model) + self.served_model_name = get_served_model_name(self.model, + self.served_model_name) + self.model = maybe_model_redirect(self.model) + # The tokenizer is consistent with the model by default. + if self.tokenizer is None: + self.tokenizer = self.model + if self.tokenizer_revision is None: + self.tokenizer_revision = self.revision + self.tokenizer = maybe_model_redirect(self.tokenizer) + + if isinstance(self.hf_config_path, str): + self.hf_config_path = maybe_model_redirect(self.hf_config_path) + + if callable(self.hf_overrides): + hf_overrides_kw = {} + hf_overrides_fn = self.hf_overrides + else: + hf_overrides_kw = self.hf_overrides + hf_overrides_fn = None + + if self.rope_scaling: + hf_override: dict[str, Any] = {"rope_scaling": self.rope_scaling} + hf_overrides_kw.update(hf_override) + hf_overrides_str = json.dumps(hf_overrides_kw) + msg = ( + "`--rope-scaling` will be removed in a future release. " + f"'Please instead use `--hf-overrides '{hf_overrides_str}'`") + warnings.warn(DeprecationWarning(msg), stacklevel=2) + if self.rope_theta is not None: + hf_override = {"rope_theta": self.rope_theta} + hf_overrides_kw.update(hf_override) + hf_overrides_str = json.dumps(hf_overrides_kw) + msg = ( + "`--rope-theta` will be removed in a future release. " + f"'Please instead use `--hf-overrides '{hf_overrides_str}'`") + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + self.maybe_pull_model_tokenizer_for_s3(self.model, self.tokenizer) + + if (backend := envs.VLLM_ATTENTION_BACKEND + ) and backend == "FLASHINFER" and find_spec("flashinfer") is None: + raise ValueError( + "VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer " + "module was not found. See " + "https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501 + "for instructions on how to install it.") + + from vllm.platforms import current_platform + + if (self.override_attention_dtype is not None + and not current_platform.is_rocm()): + warnings.warn( + "override-attention-dtype is set but not using ROCm platform", + stacklevel=2) + + if (self.enable_sleep_mode + and not current_platform.is_sleep_mode_available()): + raise ValueError( + "Sleep mode is not supported on current platform.") + + if isinstance(self.config_format, str): + self.config_format = ConfigFormat(self.config_format) + + hf_config = get_config(self.hf_config_path or self.model, + self.trust_remote_code, self.revision, + self.code_revision, self.config_format) + + if hf_overrides_kw: + logger.debug("Overriding HF config with %s", hf_overrides_kw) + hf_config.update(hf_overrides_kw) + if hf_overrides_fn: + logger.debug("Overriding HF config with %s", hf_overrides_fn) + hf_config = hf_overrides_fn(hf_config) + + self.hf_config = hf_config + + self.hf_text_config = get_hf_text_config(self.hf_config) + self.attention_chunk_size = getattr(self.hf_text_config, + "attention_chunk_size", None) + self.encoder_config = self._get_encoder_config() + self.hf_image_processor_config = get_hf_image_processor_config( + self.model, hf_token=self.hf_token, revision=self.revision) + + supported_tasks, task = self._resolve_task(self.task) + self.supported_tasks = supported_tasks + self.task = task + if self.task in ("draft", "generate"): + self.truncation_side = "left" + else: + self.truncation_side = "right" + + model_info, arch = self.registry.inspect_model_cls(self.architectures) + self._model_info = model_info + self._architecture = arch + + self.pooler_config = self._init_pooler_config() + + self.dtype = _get_and_verify_dtype( + self.model, + self.hf_config, + self.dtype, + is_pooling_model=self.runner_type == "pooling", + revision=self.revision, + ) + + # Workaround for Gemma 2 which uses interleaved sliding window + # attention, but it's not specified in its config. TODO: remove this + # when Gemma 2 is fixed in Transformers. + if self.hf_text_config.model_type == "gemma2": + self.hf_text_config.sliding_window_pattern = 2 + + sliding_window = getattr(self.hf_text_config, "sliding_window", None) + sliding_window_pattern = getattr(self.hf_text_config, + "sliding_window_pattern", None) + has_interleaved_attention = sliding_window_pattern is not None or ( + isinstance(sliding_window, list)) + + if not self.disable_sliding_window and has_interleaved_attention: + if (backend := + envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"): + sliding_window_len_min = get_min_sliding_window( + self.hf_text_config.sliding_window) + + logger.warning_once( + "%s has interleaved attention, which is currently not supported by the %s backend. Disabling sliding window and capping the max length to the sliding window size (%d).", # noqa: E501 + self.hf_text_config.model_type, + backend, + sliding_window_len_min, + ) + self.disable_sliding_window = True + else: + # for a model with interleaved attention, + # the scheduler and the model treat it as full attention + # (i.e., not dropping any tokens outside the window). + # only the attention layer itself is aware of the sliding + # window, and use the window size to compute the attention. + self.hf_text_config.interleaved_sliding_window = sliding_window + + if hasattr(self.hf_text_config, "sliding_window"): + delattr(self.hf_text_config, "sliding_window") + + sliding_window = None + + self.original_max_model_len = self.max_model_len + self.max_model_len = self.get_and_verify_max_len(self.max_model_len) + self.multimodal_config = self._init_multimodal_config() + if not self.skip_tokenizer_init: + self._verify_tokenizer_mode() + + self.is_attention_free = self._init_attention_free() + self.is_hybrid = self._init_is_hybrid() + self.has_noops = self._init_has_noops() + self.has_inner_state = self._init_has_inner_state() + + if (not current_platform.is_neuron() and self.override_neuron_config): + raise ValueError( + "`override_neuron_config` is only supported on Neuron.") + + self._verify_quantization() + self._verify_cuda_graph() + self._verify_bnb_config() + + @field_validator("quantization", mode="before") + @classmethod + def validate_quantization_before(cls, value: Any) -> Any: + if isinstance(value, str): + return value.lower() + return value + + @model_validator(mode="after") + def validate_model_config_after(self: "ModelConfig") -> "ModelConfig": + if not isinstance(self.tokenizer, str): + raise ValueError("tokenizer must be a string after __post_init__.") + if not isinstance(self.max_model_len, int): + raise ValueError( + "max_model_len must be an integer after __post_init__.") + return self + + @property + def registry(self): + return me_models.ModelRegistry + + @property + def architectures(self) -> list[str]: + # architectures in the model config. + return getattr(self.hf_config, "architectures", []) + + @property + def architecture(self) -> str: + # The architecture vllm actually used. + return self._architecture + + @property + def model_info(self): + return self._model_info + + def maybe_pull_model_tokenizer_for_s3(self, model: str, + tokenizer: str) -> None: + """Pull model/tokenizer from S3 to temporary directory when needed. + + Args: + model: Model name or path + tokenizer: Tokenizer name or path + """ + if not (is_s3(model) or is_s3(tokenizer)): + return + + if is_s3(model): + s3_model = S3Model() + s3_model.pull_files(model, + allow_pattern=["*.model", "*.py", "*.json"]) + self.model_weights = model + self.model = s3_model.dir + + # If tokenizer is same as model, download to same directory + if model == tokenizer: + s3_model.pull_files( + model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) + self.tokenizer = s3_model.dir + return + + # Only download tokenizer if needed and not already handled + if is_s3(tokenizer): + s3_tokenizer = S3Model() + s3_tokenizer.pull_files( + model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) + self.tokenizer = s3_tokenizer.dir + + def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: + if self.registry.is_multimodal_model(self.architectures): + return MultiModalConfig( + limit_per_prompt=self.limit_mm_per_prompt, + media_io_kwargs=self.media_io_kwargs, + mm_processor_kwargs=self.mm_processor_kwargs, + disable_mm_preprocessor_cache=self. + disable_mm_preprocessor_cache) + + if self.limit_mm_per_prompt: + raise ValueError("`limit_mm_per_prompt` is only supported for " + "multimodal models.") + if self.mm_processor_kwargs: + raise ValueError("`mm_processor_kwargs` is only supported for " + "multimodal models.") + if self.disable_mm_preprocessor_cache: + raise ValueError("`disable_mm_preprocessor_cache` is only " + "supported for multimodal models.") + + return None + + def _get_encoder_config(self): + return get_sentence_transformer_tokenizer_config( + self.model, self.revision) + + def _init_pooler_config(self) -> Optional["PoolerConfig"]: + if self.runner_type == "pooling": + if isinstance(self.override_pooler_config, dict): + self.override_pooler_config = PoolerConfig( + **self.override_pooler_config) + + pooler_config = self.override_pooler_config or PoolerConfig() + + base_config = get_pooling_config(self.model, self.revision) + if base_config is not None: + # Only set values that are not overridden by the user + for k, v in base_config.items(): + if getattr(pooler_config, k) is None: + setattr(pooler_config, k, v) + + if self.is_matryoshka: + if pooler_config.normalize is None: + pooler_config.normalize = True + elif not pooler_config.normalize: + raise ValueError( + "`normalize` must be enabled (set to True) " + "for models that are compatible with " + "Matryoshka Representation.") + + return pooler_config + + return None + + def _init_attention_free(self) -> bool: + return self.registry.is_attention_free_model(self.architectures) + + def _init_is_hybrid(self) -> bool: + return self.registry.is_hybrid_model(self.architectures) + + def _init_has_noops(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return self.registry.is_noops_model(architectures) + + def _init_has_inner_state(self) -> bool: + return self.registry.model_has_inner_state(self.architectures) + + def _verify_tokenizer_mode(self) -> None: + tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower()) + if tokenizer_mode not in get_args(TokenizerMode): + raise ValueError( + f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " + f"one of {get_args(TokenizerMode)}.") + self.tokenizer_mode = tokenizer_mode + + def _get_preferred_task( + self, + architectures: list[str], + supported_tasks: set[_ResolvedTask], + ) -> Optional[_ResolvedTask]: + model_id = self.model + if get_pooling_config(model_id, self.revision): + return "embed" + if self.registry.is_cross_encoder_model(architectures): + return "classify" + if self.registry.is_transcription_model(architectures): + return "transcription" + + suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [ + # Other models follow this pattern + ("ForCausalLM", "generate"), + ("ForConditionalGeneration", "generate"), + ("ForSequenceClassification", "classify"), + ("ChatModel", "generate"), + ("LMHeadModel", "generate"), + ("EmbeddingModel", "embed"), + ("RewardModel", "reward"), + ] + _, arch = self.registry.inspect_model_cls(architectures) + + for suffix, pref_task in suffix_to_preferred_task: + if arch.endswith(suffix) and pref_task in supported_tasks: + return pref_task + + return None + + def _resolve_task( + self, + task_option: Literal[TaskOption, Literal["draft"]], + ) -> tuple[set[_ResolvedTask], _ResolvedTask]: + if task_option == "draft": + return {"draft"}, "draft" + + registry = self.registry + architectures = self.architectures + + runner_support: dict[RunnerType, bool] = { + # NOTE: Listed from highest to lowest priority, + # in case the model supports multiple of them + "transcription": registry.is_transcription_model(architectures), + "generate": registry.is_text_generation_model(architectures), + "pooling": registry.is_pooling_model(architectures), + } + supported_runner_types_lst: list[RunnerType] = [ + runner_type + for runner_type, is_supported in runner_support.items() + if is_supported + ] + + supported_tasks_lst: list[_ResolvedTask] = [ + task for runner_type in supported_runner_types_lst + for task in _RUNNER_TASKS[runner_type] + ] + supported_tasks = set(supported_tasks_lst) + + if task_option == "auto": + selected_task = next(iter(supported_tasks_lst)) + + if len(supported_tasks_lst) > 1: + preferred_task = self._get_preferred_task( + architectures, supported_tasks) + if preferred_task is not None: + selected_task = preferred_task + + logger.info( + "This model supports multiple tasks: %s. " + "Defaulting to '%s'.", supported_tasks, selected_task) + else: + if task_option == "score": + if not runner_support["pooling"]: + msg = (f"This model does not support the '{task_option}' " + f"task. Supported tasks: {supported_tasks}") + raise ValueError(msg) + if self.registry.is_cross_encoder_model(architectures): + task_option = "classify" + else: + task_option = "embed" + else: + # Aliases + if task_option == "embedding": + msg = ("The 'embedding' task has been renamed to " + "'embed', please use the new name. The old name " + "will be removed in v1.0.") + warnings.warn(msg, DeprecationWarning, stacklevel=2) + + task_option = "embed" + + if task_option not in supported_tasks: + msg = ( + f"This model does not support the '{task_option}' task. " + f"Supported tasks: {supported_tasks}") + raise ValueError(msg) + + selected_task = task_option + + return supported_tasks, selected_task + + def _parse_quant_hf_config(self): + quant_cfg = getattr(self.hf_config, "quantization_config", None) + if quant_cfg is None: + # compressed-tensors uses a "compression_config" key + quant_cfg = getattr(self.hf_config, "compression_config", None) + return quant_cfg + + def _verify_quantization(self) -> None: + supported_quantization = me_quant.QUANTIZATION_METHODS + optimized_quantization_methods = [ + "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", + "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8", + "quark", "modelopt_fp4", "bitblas", "gptq_bitblas", + "slimquant_w4a8","slimquant_w4a8_marlin" + ] + if self.quantization is not None: + self.quantization = cast(me_quant.QuantizationMethods, + self.quantization) + + # Parse quantization method from the HF model config, if available. + quant_cfg = self._parse_quant_hf_config() + + if quant_cfg is not None: + quant_method = quant_cfg.get("quant_method", "").lower() + quant_method = quant_method.replace("compressed_tensors", + "compressed-tensors") + quant_cfg["quant_method"] = quant_method + + # Quantization methods which are overrides (i.e. they have a + # `override_quantization_method` method) must be checked in order + # of preference (this is particularly important for GPTQ). + overrides = [ + "marlin", + "bitblas", + "gptq_marlin_24", + "gptq_marlin", + "gptq_bitblas", + "awq_marlin", + "ipex", + "moe_wna16", + "slimquant_w4a8_marlin" + ] + quantization_methods = [ + q for q in supported_quantization if q not in overrides + ] + # Any custom overrides will be in quantization_methods so we place + # them at the start of the list so custom overrides have preference + # over the built in ones. + quantization_methods = quantization_methods + overrides + + # Detect which checkpoint is it + for name in quantization_methods: + method = me_quant.get_quantization_config(name) + quantization_override = method.override_quantization_method( + quant_cfg, self.quantization) + if quantization_override is not None: + # Raise error if the override is not custom (custom would + # be in QUANTIZATION_METHODS but not QuantizationMethods) + # and hasn't been added to the overrides list. + if (name in get_args(me_quant.QuantizationMethods) + and name not in overrides): + raise ValueError( + f"Quantization method {name} is an override but " + "is has not been added to the `overrides` list " + "above. This is necessary to ensure that the " + "overrides are checked in order of preference.") + quant_method = quantization_override + self.quantization = quantization_override + break + + # Verify quantization configurations. + if self.quantization is None: + self.quantization = quant_method + elif self.quantization != quant_method: + raise ValueError( + "Quantization method specified in the model config " + f"({quant_method}) does not match the quantization " + f"method specified in the `quantization` argument " + f"({self.quantization}).") + + if self.quantization is not None: + if self.quantization not in supported_quantization: + raise ValueError( + f"Unknown quantization method: {self.quantization}. Must " + f"be one of {supported_quantization}.") + from vllm.platforms import current_platform + current_platform.verify_quantization(self.quantization) + if self.quantization not in optimized_quantization_methods: + logger.warning( + "%s quantization is not fully " + "optimized yet. The speed can be slower than " + "non-quantized models.", self.quantization) + + def _verify_cuda_graph(self) -> None: + if self.max_seq_len_to_capture is None: + self.max_seq_len_to_capture = self.max_model_len + self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, + self.max_model_len) + # self.max_seq_len_to_capture = self.max_model_len + # CUDAGraph capture not supported for enc-dec models and mllama on ROCm + ROCM_UNSUPPORTED_MODELS = ['mllama'] + unsupported_rocm = (self.hf_config.model_type + in ROCM_UNSUPPORTED_MODELS + or self.is_encoder_decoder) + + if (unsupported_rocm and not self.enforce_eager + and current_platform.is_rocm()): + logger.warning( + "CUDA graph is not supported for %s on ROCm yet, fallback " + "to eager mode.", self.hf_config.model_type) + self.enforce_eager = True + + def _verify_bnb_config(self) -> None: + """ + The current version of bitsandbytes (0.46.1) with 8-bit models does not + yet support CUDA graph. + # TODO Remove this when bitsandbytes supports. + """ + is_bitsandbytes = self.quantization == "bitsandbytes" + has_quantization_config = (getattr(self.hf_config, + "quantization_config", None) + is not None) + is_8bit = (self.hf_config.quantization_config.get( + "load_in_8bit", False) if has_quantization_config else False) + if all([ + is_bitsandbytes, + has_quantization_config, + is_8bit, + not self.enforce_eager, + ]): + logger.warning( + "CUDA graph is not supported on BitsAndBytes 8bit yet, " + "fallback to the eager mode.") + + self.enforce_eager = True + + def _verify_with_expert_parallelism(self) -> None: + num_expert_names = [ + "moe_num_experts", # Dbrx + "num_experts", # Jamba + "n_routed_experts", # DeepSeek + "num_local_experts", # Mixtral + ] + num_experts = 0 + for name in num_expert_names: + num_experts = getattr(self.hf_text_config, name, 0) + if num_experts > 0: + break + if num_experts < 1: + raise ValueError( + "Number of experts in the model must be greater than 0 " + "when expert parallelism is enabled.") + + def verify_dual_chunk_attention_config( + self, + load_config: "LoadConfig", + ) -> None: + if hasattr(self.hf_config, "dual_chunk_attention_config"): + # Try loading the sparse attention config + from vllm.model_executor.model_loader.weight_utils import ( + get_sparse_attention_config) + sparse_attn_config = get_sparse_attention_config(self, load_config) + if sparse_attn_config: + self.hf_config.dual_chunk_attention_config[ + "sparse_attention_config"] = sparse_attn_config + if "sparse_attention_enabled" not in \ + self.hf_config.dual_chunk_attention_config: + self.hf_config.dual_chunk_attention_config[ + "sparse_attention_enabled"] = True + + def verify_async_output_proc(self, parallel_config, speculative_config, + device_config) -> None: + if not self.use_async_output_proc: + # Nothing to check + return + + if parallel_config.pipeline_parallel_size > 1: + self.use_async_output_proc = False + return + + # Reminder: Please update docs/features/compatibility_matrix.md + # If the feature combo become valid + from vllm.platforms import current_platform + if not current_platform.is_async_output_supported(self.enforce_eager): + self.use_async_output_proc = False + return + + if envs.VLLM_USE_RAY_SPMD_WORKER: + self.use_async_output_proc = False + return + + # Async postprocessor is not necessary for pooling models + # since there is no token generation + if self.runner_type == "pooling": + self.use_async_output_proc = False + + # Reminder: Please update docs/features/compatibility_matrix.md + # If the feature combo become valid + if speculative_config: + self.use_async_output_proc = False + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + + if parallel_config.distributed_executor_backend == "external_launcher": + assert self.seed is not None, ( + "Seed must be set when using external launcher backend to " + "make sure sampling results are the same across workers.") + + total_num_attention_heads = getattr(self.hf_text_config, + "num_attention_heads", 0) + tensor_parallel_size = parallel_config.tensor_parallel_size + if total_num_attention_heads % tensor_parallel_size != 0: + raise ValueError( + f"Total number of attention heads ({total_num_attention_heads})" + " must be divisible by tensor parallel size " + f"({tensor_parallel_size}).") + + if parallel_config.enable_expert_parallel: + self._verify_with_expert_parallelism() + + pipeline_parallel_size = parallel_config.pipeline_parallel_size + if pipeline_parallel_size > 1: + if not self.registry.is_pp_supported_model(self.architectures): + raise NotImplementedError( + "Pipeline parallelism is not supported for this model. " + "Supported models implement the `SupportsPP` interface.") + + if self.use_async_output_proc: + self.use_async_output_proc = False + + def get_hf_config_sliding_window( + self) -> Union[Optional[int], list[Optional[int]]]: + """Get the sliding window size, or None if disabled.""" + + # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in + # addition to sliding window size. We check if that field is present + # and if it's False, return None. + if (hasattr(self.hf_text_config, "use_sliding_window") + and not self.hf_text_config.use_sliding_window): + return None + return getattr(self.hf_text_config, "sliding_window", None) + + def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]: + """Get the sliding window size, or None if disabled. + """ + # If user disables sliding window, return None. + if self.disable_sliding_window: + return None + # Otherwise get the value from the hf config. + return self.get_hf_config_sliding_window() + + def get_vocab_size(self) -> int: + return self.hf_text_config.vocab_size + + def get_hidden_size(self) -> int: + return self.hf_text_config.hidden_size + + @property + def is_deepseek_mla(self) -> bool: + if not hasattr(self.hf_text_config, "model_type"): + return False + elif self.hf_text_config.model_type in \ + ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'): + return self.hf_text_config.kv_lora_rank is not None + elif self.hf_text_config.model_type == 'eagle': + # if the model is an EAGLE module, check for the + # underlying architecture + return self.hf_text_config.model.model_type in \ + ('deepseek_v2', 'deepseek_v3') \ + and self.hf_text_config.kv_lora_rank is not None + return False + + def get_head_size(self) -> int: + # TODO remove hard code + if self.is_deepseek_mla: + qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", + 0) + if self.use_mla: + return self.hf_text_config.kv_lora_rank + qk_rope_head_dim + else: + qk_nope_head_dim = getattr(self.hf_text_config, + "qk_nope_head_dim", 0) + if qk_rope_head_dim and qk_nope_head_dim: + return qk_rope_head_dim + qk_nope_head_dim + + if hasattr(self.hf_text_config, + "model_type") and (self.hf_text_config.model_type + == "zamba2"): + return self.hf_text_config.attention_head_dim + + if self.is_attention_free: + return 0 + + # NOTE: Some configs may set head_dim=None in the config + if getattr(self.hf_text_config, "head_dim", None) is not None: + return self.hf_text_config.head_dim + + # FIXME(woosuk): This may not be true for all models. + return (self.hf_text_config.hidden_size // + self.hf_text_config.num_attention_heads) + + def get_total_num_kv_heads(self) -> int: + """Returns the total number of KV heads.""" + # For GPTBigCode & Falcon: + # NOTE: for falcon, when new_decoder_architecture is True, the + # multi_query flag is ignored and we use n_head_kv for the number of + # KV heads. + falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] + new_decoder_arch_falcon = ( + self.hf_config.model_type in falcon_model_types + and getattr(self.hf_config, "new_decoder_architecture", False)) + if not new_decoder_arch_falcon and getattr(self.hf_text_config, + "multi_query", False): + # Multi-query attention, only one KV head. + # Currently, tensor parallelism is not supported in this case. + return 1 + + # For DBRX and MPT + if self.hf_config.model_type == "mpt": + if "kv_n_heads" in self.hf_config.attn_config: + return self.hf_config.attn_config["kv_n_heads"] + return self.hf_config.num_attention_heads + if self.hf_config.model_type == "dbrx": + return getattr(self.hf_config.attn_config, "kv_n_heads", + self.hf_config.num_attention_heads) + + if self.hf_config.model_type == "nemotron-nas": + for block in self.hf_config.block_configs: + if not block.attention.no_op: + return self.hf_config.num_attention_heads \ + // block.attention.n_heads_in_group + + raise RuntimeError("Couldn't determine number of kv heads") + + if self.is_attention_free: + return 0 + + attributes = [ + # For Falcon: + "n_head_kv", + "num_kv_heads", + # For LLaMA-2: + "num_key_value_heads", + # For ChatGLM: + "multi_query_group_num", + ] + for attr in attributes: + num_kv_heads = getattr(self.hf_text_config, attr, None) + if num_kv_heads is not None: + return num_kv_heads + + # For non-grouped-query attention models, the number of KV heads is + # equal to the number of attention heads. + return self.hf_text_config.num_attention_heads + + def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: + """Returns the number of KV heads per GPU.""" + if self.use_mla: + # When using MLA during decode it becomes MQA + return 1 + + total_num_kv_heads = self.get_total_num_kv_heads() + # If tensor parallelism is used, we divide the number of KV heads by + # the tensor parallel size. We will replicate the KV heads in the + # case where the number of KV heads is smaller than the tensor + # parallel size so each GPU has at least one KV head. + return max(1, + total_num_kv_heads // parallel_config.tensor_parallel_size) + + def get_num_attention_heads(self, + parallel_config: "ParallelConfig") -> int: + num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) + return num_heads // parallel_config.tensor_parallel_size + + def get_layers_start_end_indices( + self, parallel_config: "ParallelConfig") -> tuple[int, int]: + from vllm.distributed.utils import get_pp_indices + if (self.hf_text_config.model_type == "deepseek_mtp" + or self.hf_config.model_type == "mimo_mtp" + or self.hf_config.model_type == "glm4_moe_mtp"): + total_num_hidden_layers = getattr(self.hf_text_config, + "num_nextn_predict_layers", 0) + else: + total_num_hidden_layers = getattr(self.hf_text_config, + "num_hidden_layers", 0) + # the layout order is: DP x PP x TP + pp_rank = (parallel_config.rank // parallel_config.tensor_parallel_size + ) % parallel_config.pipeline_parallel_size + pp_size = parallel_config.pipeline_parallel_size + start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size) + return start, end + + def get_num_layers(self, parallel_config: "ParallelConfig") -> int: + start, end = self.get_layers_start_end_indices(parallel_config) + return end - start + + def get_num_layers_by_block_type( + self, + parallel_config: "ParallelConfig", + block_type: LayerBlockType = LayerBlockType.attention, + ) -> int: + # This function relies on 'layers_block_type' in hf_config, + # for w/o this attribute, we will need to have workarounds like so + attn_block_type = block_type == LayerBlockType.attention + is_transformer = not self.is_hybrid and \ + not self.has_noops and \ + not self.is_attention_free + start, end = self.get_layers_start_end_indices(parallel_config) + + if is_transformer: + # Handle the basic case first + return end - start if attn_block_type else 0 + elif self.is_attention_free: + # Attention free + # Note that this code assumes there + # is only one type of attention-free block type. + return 0 if attn_block_type else end - start + elif self.has_noops: + block_configs = self.hf_config.block_configs + return sum(not bc.attention.no_op + for bc in block_configs[start:end]) + else: + # Hybrid model Jamba + layers_block_type_value = getattr(self.hf_config, + "layers_block_type", None) + if layers_block_type_value is not None: + if hasattr(self.hf_text_config, + "model_type") and (self.hf_text_config.model_type + == "zamba2"): + if attn_block_type: + return sum(t == "hybrid" + for t in layers_block_type_value[start:end]) + else: + return self.get_num_layers(parallel_config) + return sum(t == block_type.value + for t in layers_block_type_value[start:end]) + + # Hybrid model Minimax + attn_type_list = getattr(self.hf_config, "attn_type_list", None) + if attn_type_list: + return sum(t == 1 for t in attn_type_list[start:end]) + + if layers_block_type_value is None and attn_type_list is None: + raise ValueError( + "The model is an hybrid without a" + "layers_block_type or an attn_type_list in the hf_config," + "cannot determine the num of " + f"{block_type.value} layers") + + return sum(t == 1 for t in attn_type_list[start:end]) + + def get_multimodal_config(self) -> "MultiModalConfig": + """ + Get the multimodal configuration of the model. + + Raises: + ValueError: If the model is not multimodal. + """ + if self.multimodal_config is None: + raise ValueError("The model is not multimodal.") + + return self.multimodal_config + + def try_get_generation_config(self) -> dict[str, Any]: + if self.generation_config in ("auto", "vllm"): + config = try_get_generation_config( + self.hf_config_path or self.model, + trust_remote_code=self.trust_remote_code, + revision=self.revision, + ) + else: + config = try_get_generation_config( + self.generation_config, + trust_remote_code=self.trust_remote_code, + ) + + if config is None: + return {} + + return config.to_diff_dict() + + def get_diff_sampling_param(self) -> dict[str, Any]: + """ + This method returns a dictionary containing the parameters + that differ from the default sampling parameters. If + `generation_config` is `"vllm"`, an empty dictionary is returned. + + Returns: + dict[str, Any]: A dictionary with the differing sampling + parameters, if `generation_config` is `"vllm"` an empty dictionary. + """ + if self.generation_config == "vllm": + config = {} + else: + config = self.try_get_generation_config() + + # Overriding with given generation config + config.update(self.override_generation_config) + + available_params = [ + "repetition_penalty", + "temperature", + "top_k", + "top_p", + "min_p", + "max_new_tokens", + ] + if any(p in config for p in available_params): + diff_sampling_param = { + p: config.get(p) + for p in available_params if config.get(p) is not None + } + # Huggingface definition of max_new_tokens is equivalent + # to vLLM's max_tokens + if "max_new_tokens" in diff_sampling_param: + diff_sampling_param["max_tokens"] = diff_sampling_param.pop( + "max_new_tokens") + else: + diff_sampling_param = {} + + if diff_sampling_param: + logger.warning_once( + "Default sampling parameters have been overridden by the " + "model's Hugging Face generation config recommended from the " + "model creator. If this is not intended, please relaunch " + "vLLM instance with `--generation-config vllm`.") + return diff_sampling_param + + @property + def is_encoder_decoder(self) -> bool: + """Extract the HF encoder/decoder model flag.""" + """ + For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to + True to enable cross-attention + Neuron needs all multimodal data to be in the decoder and does not + need to explicitly enable cross-attention + """ + if (current_platform.is_neuron() + and self.hf_config.model_type == "mllama"): + return False + + return is_encoder_decoder(self.hf_config) + + @property + def uses_mrope(self) -> bool: + return uses_mrope(self.hf_config) + + @property + def is_multimodal_model(self) -> bool: + return self.multimodal_config is not None + + @property + def is_cross_encoder(self) -> bool: + return self.task == "classify" + + @property + def use_mla(self) -> bool: + return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE and SUPPORT_TC + + @property + def supported_runner_types(self) -> set[RunnerType]: + return {_TASK_RUNNER[task] for task in self.supported_tasks} + + @property + def runner_type(self) -> RunnerType: + return _TASK_RUNNER[cast(_ResolvedTask, self.task)] + + @property + def is_v1_compatible(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return me_models.ModelRegistry.is_v1_compatible(architectures) + + @property + def is_matryoshka(self) -> bool: + return (hasattr(self.hf_config, "matryoshka_dimensions") + or getattr(self.hf_config, "is_matryoshka", False)) + + @property + def matryoshka_dimensions(self): + return getattr(self.hf_config, "matryoshka_dimensions", None) + + def get_and_verify_max_len(self, max_model_len: int): + # For pooling models, the tokenizer's `model_max_length` is often a + # reliable source for the maximum sequence length. However, for + # generative models, this can be incorrect and unduly limit the + # context window (e.g., DeepSeek-R1). Therefore, we only consider + # tokenizer_config for pooling models. + tokenizer_config = None + if self.runner_type == "pooling": + tokenizer_config = try_get_tokenizer_config( + self.tokenizer, + trust_remote_code=self.trust_remote_code, + revision=self.tokenizer_revision) + max_model_len = _get_and_verify_max_len( + hf_config=self.hf_text_config, + tokenizer_config=tokenizer_config, + max_model_len=max_model_len, + disable_sliding_window=self.disable_sliding_window, + sliding_window_len=self.get_hf_config_sliding_window(), + spec_target_max_model_len=self.spec_target_max_model_len, + encoder_config=self.encoder_config) + logger.info("Using max model len %s", max_model_len) + return max_model_len + + +BlockSize = Literal[1, 8, 16, 32, 64, 128] +CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "int8"] +PrefixCachingHashAlgo = Literal["builtin", "sha256"] + + +@config +@dataclass +class CacheConfig: + """Configuration for the KV cache.""" + + block_size: BlockSize = 64 if envs.VLLM_USE_FLASH_ATTN_PA or envs.VLLM_USE_FLASH_MLA else 16 # type: ignore + """Size of a contiguous cache block in number of tokens. This is ignored on + neuron devices and set to `--max-model-len`. On CUDA devices, only block + sizes up to 32 are supported. On HPU devices, block size defaults to 128. + + This config has no static default. If left unspecified by the user, it will + be set in `Platform.check_and_update_config()` based on the current + platform.""" + gpu_memory_utilization: float = 0.9 + """The fraction of GPU memory to be used for the model executor, which can + range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory + utilization. If unspecified, will use the default value of 0.9. This is a + per-instance limit, and only applies to the current vLLM instance. It does + not matter if you have another vLLM instance running on the same GPU. For + example, if you have two vLLM instances running on the same GPU, you can + set the GPU memory utilization to 0.5 for each instance.""" + swap_space: float = 4 + """Size of the CPU swap space per GPU (in GiB).""" + cache_dtype: CacheDType = "auto" + """Data type for kv cache storage. If "auto", will use model data type. + CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports + fp8 (=fp8_e4m3).""" + is_attention_free: bool = False + """Whether the model is attention-free. This is primarily set in + `ModelConfig` and that value should be manually duplicated here.""" + num_gpu_blocks_override: Optional[int] = None + """Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks` + if specified. Does nothing if `None`. Used for testing preemption.""" + sliding_window: Optional[int] = None + """Sliding window size for the KV cache. This is primarily set in + `ModelConfig` and that value should be manually duplicated here.""" + enable_prefix_caching: Optional[bool] = None + """Whether to enable prefix caching. Disabled by default for V0. Enabled by + default for V1.""" + prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin" + """Set the hash algorithm for prefix caching:\n + - "builtin" is Python's built-in hash.\n + - "sha256" is collision resistant but with certain overheads.""" + cpu_offload_gb: float = 0 + """The space in GiB to offload to CPU, per GPU. Default is 0, which means + no offloading. Intuitively, this argument can be seen as a virtual way to + increase the GPU memory size. For example, if you have one 24 GB GPU and + set this to 10, virtually you can think of it as a 34 GB GPU. Then you can + load a 13B model with BF16 weight, which requires at least 26GB GPU memory. + Note that this requires fast CPU-GPU interconnect, as part of the model is + loaded from CPU memory to GPU memory on the fly in each model forward pass. + """ + calculate_kv_scales: bool = False + """This enables dynamic calculation of `k_scale` and `v_scale` when + kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model + checkpoint if available. Otherwise, the scales will default to 1.0.""" + cpu_kvcache_space_bytes: Optional[int] = None + """(CPU backend only) CPU key-value cache space.""" + + # Will be set after profiling. + num_gpu_blocks: Optional[int] = field(default=None, init=False) + """The number of blocks to allocate for GPU memory.""" + num_cpu_blocks: Optional[int] = field(default=None, init=False) + """The number of blocks to allocate for CPU memory.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + factors.append(self.cache_dtype) + # `cpu_offload_gb` does not use `torch.compile` yet. + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self) -> None: + self.swap_space_bytes = self.swap_space * GiB_bytes + + self._verify_cache_dtype() + self._verify_prefix_caching() + + def metrics_info(self): + # convert cache_config to dict(key: str, value: str) for prometheus + # metrics info + return {key: str(value) for key, value in self.__dict__.items()} + + @model_validator(mode='after') + def _verify_args(self) -> Self: + if self.cpu_offload_gb < 0: + raise ValueError("CPU offload space must be non-negative" + f", but got {self.cpu_offload_gb}") + + if self.gpu_memory_utilization > 1.0: + raise ValueError( + "GPU memory utilization must be less than 1.0. Got " + f"{self.gpu_memory_utilization}.") + + return self + + def _verify_cache_dtype(self) -> None: + if self.cache_dtype == "auto": + pass + elif self.cache_dtype in get_args(CacheDType): + logger.info( + "Using fp8 data type to store kv cache. It reduces the GPU " + "memory footprint and boosts the performance. " + "Meanwhile, it may cause accuracy drop without a proper " + "scaling factor") + else: + raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") + + def _verify_prefix_caching(self) -> None: + if not self.enable_prefix_caching: + return + + if self.sliding_window is not None and not envs.VLLM_USE_V1: + raise NotImplementedError( + "Prefix caching is not supported with sliding window. " + "Run with --disable-sliding-window to use prefix caching.") + + if (self.enable_prefix_caching and self.prefix_caching_hash_algo + not in get_args(PrefixCachingHashAlgo)): + raise ValueError( + "Unknown prefix caching hash algorithm: " + f"{self.prefix_caching_hash_algo}. Must be one of " + f"{get_args(PrefixCachingHashAlgo)}.") + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + total_cpu_memory = get_cpu_memory() + # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel + # group are in the same node. However, the GPUs may span multiple nodes. + num_gpus_per_node = parallel_config.tensor_parallel_size + cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node + + msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the " + f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory " + "is allocated for the swap space.") + if cpu_memory_usage > 0.7 * total_cpu_memory: + raise ValueError("Too large swap space. " + msg) + elif cpu_memory_usage > 0.4 * total_cpu_memory: + logger.warning("Possibly too large swap space. %s", msg) + + +@config +@dataclass +class TokenizerPoolConfig: + """This config is deprecated and will be removed in a future release. + + Passing these parameters will have no effect. Please remove them from your + configurations. + """ + + pool_size: int = 0 + """This parameter is deprecated and will be removed in a future release. + Passing this parameter will have no effect. Please remove it from your + configurations.""" + pool_type: str = "ray" + """This parameter is deprecated and will be removed in a future release. + Passing this parameter will have no effect. Please remove it from your + configurations.""" + extra_config: dict = field(default_factory=dict) + """This parameter is deprecated and will be removed in a future release. + Passing this parameter will have no effect. Please remove it from your + configurations.""" + + def __post_init__(self) -> None: + logger.warning_once( + "TokenizerPoolConfig is deprecated and will be removed in a " + "future release. Passing this parameter will have no effect. " + "Please remove it from your configurations.") + + +class LoadFormat(str, enum.Enum): + AUTO = "auto" + PT = "pt" + SAFETENSORS = "safetensors" + NPCACHE = "npcache" + DUMMY = "dummy" + TENSORIZER = "tensorizer" + SHARDED_STATE = "sharded_state" + GGUF = "gguf" + BITSANDBYTES = "bitsandbytes" + MISTRAL = "mistral" + RUNAI_STREAMER = "runai_streamer" + RUNAI_STREAMER_SHARDED = "runai_streamer_sharded" + FASTSAFETENSORS = "fastsafetensors" + + +@config +@dataclass +class LoadConfig: + """Configuration for loading the model weights.""" + + load_format: Union[str, LoadFormat, + "BaseModelLoader"] = LoadFormat.AUTO.value + """The format of the model weights to load:\n + - "auto" will try to load the weights in the safetensors format and fall + back to the pytorch bin format if safetensors format is not available.\n + - "pt" will load the weights in the pytorch bin format.\n + - "safetensors" will load the weights in the safetensors format.\n + - "npcache" will load the weights in pytorch format and store a numpy cache + to speed up the loading.\n + - "dummy" will initialize the weights with random values, which is mainly + for profiling.\n + - "tensorizer" will use CoreWeave's tensorizer library for fast weight + loading. See the Tensorize vLLM Model script in the Examples section for + more information.\n + - "runai_streamer" will load the Safetensors weights using Run:ai Model + Streamer.\n + - "bitsandbytes" will load the weights using bitsandbytes quantization.\n + - "sharded_state" will load weights from pre-sharded checkpoint files, + supporting efficient loading of tensor-parallel models.\n + - "gguf" will load weights from GGUF format files (details specified in + https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n + - "mistral" will load weights from consolidated safetensors files used by + Mistral models.""" + download_dir: Optional[str] = None + """Directory to download and load the weights, default to the default + cache directory of Hugging Face.""" + model_loader_extra_config: Union[dict, TensorizerConfig] = field( + default_factory=dict) + """Extra config for model loader. This will be passed to the model loader + corresponding to the chosen load_format.""" + ignore_patterns: Optional[Union[list[str], str]] = None + """The list of patterns to ignore when loading the model. Default to + "original/**/*" to avoid repeated loading of llama's checkpoints.""" + use_tqdm_on_load: bool = True + """Whether to enable tqdm for showing progress bar when loading model + weights.""" + pt_load_map_location: Union[str, dict[str, str]] = "cpu" + """ + pt_load_map_location: the map location for loading pytorch checkpoint, to + support loading checkpoints can only be loaded on certain devices like + "cuda", this is equivalent to {"": "cuda"}. Another supported format is + mapping from different devices like from GPU 1 to GPU 0: + {"cuda:1": "cuda:0"}. Note that when passed from command line, the strings + in dictionary needs to be double quoted for json parsing. For more details, + see original doc for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + if isinstance(self.load_format, str): + load_format = self.load_format.lower() + self.load_format = LoadFormat(load_format) + + if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: + logger.info( + "Ignoring the following patterns when downloading weights: %s", + self.ignore_patterns) + else: + self.ignore_patterns = ["original/**/*"] + + +DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] + + +@config +@dataclass +class ParallelConfig: + """Configuration for the distributed execution.""" + + pipeline_parallel_size: int = 1 + """Number of pipeline parallel groups.""" + tensor_parallel_size: int = 1 + """Number of tensor parallel groups.""" + data_parallel_size: int = 1 + """Number of data parallel groups. MoE layers will be sharded according to + the product of the tensor parallel size and data parallel size.""" + data_parallel_size_local: int = 1 + """Number of local data parallel groups.""" + data_parallel_rank: int = 0 + """Rank of the data parallel group.""" + data_parallel_rank_local: Optional[int] = None + """Local rank of the data parallel group, + set only in SPMD mode.""" + data_parallel_master_ip: str = "127.0.0.1" + """IP of the data parallel master.""" + data_parallel_rpc_port: int = 29550 + """Port for data parallel messaging.""" + data_parallel_master_port: int = 29500 + """Port of the data parallel master.""" + data_parallel_backend: str = "mp" + """Backend to use for data parallel, either "mp" or "ray".""" + data_parallel_external_lb: bool = False + """Whether to use "external" DP LB mode. Applies only to online serving + and when data_parallel_size > 0. Set implicitly when + data_parallel_rank is provided explicitly to vllm serve.""" + enable_expert_parallel: bool = False + """Use expert parallelism instead of tensor parallelism for MoE layers.""" + enable_eplb: bool = False + """Enable expert parallelism load balancing for MoE layers.""" + num_redundant_experts: int = 0 + """Number of redundant experts to use for expert parallelism.""" + eplb_window_size: int = 1000 + """Window size for expert load recording.""" + eplb_step_interval: int = 3000 + """ + Interval for rearranging experts in expert parallelism. + + Note that if this is greater than the EPLB window size, only the metrics + of the last `eplb_window_size` steps will be used for rearranging experts. + """ + eplb_log_balancedness: bool = False + """ + Log the balancedness each step of expert parallelism. + This is turned off by default since it will cause communication overhead. + """ + + max_parallel_loading_workers: Optional[int] = None + """Maximum number of parallel loading workers when loading model + sequentially in multiple batches. To avoid RAM OOM when using tensor + parallel and large models.""" + + disable_custom_all_reduce: bool = False + """Disable the custom all-reduce kernel and fall back to NCCL.""" + + tokenizer_pool_config: Optional[TokenizerPoolConfig] = None + """This parameter is deprecated and will be removed in a future release. + Please remove it from your configs""" + + ray_workers_use_nsight: bool = False + """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" + + placement_group: Optional["PlacementGroup"] = None + """ray distributed model workers placement group.""" + + distributed_executor_backend: Optional[Union[DistributedExecutorBackend, + type["ExecutorBase"]]] = None + """Backend to use for distributed model + workers, either "ray" or "mp" (multiprocessing). If the product + of pipeline_parallel_size and tensor_parallel_size is less than + or equal to the number of GPUs available, "mp" will be used to + keep processing on a single host. Otherwise, this will default + to "ray" if Ray is installed and fail otherwise. Note that tpu + and hpu only support Ray for distributed inference.""" + + worker_cls: str = "auto" + """The full name of the worker class to use. If "auto", the worker class + will be determined based on the platform.""" + sd_worker_cls: str = "auto" + """The full name of the worker class to use for speculative decoding. + If "auto", the worker class will be determined based on the platform.""" + worker_extension_cls: str = "" + """The full name of the worker extension class to use. The worker extension + class is dynamically inherited by the worker class. This is used to inject + new attributes and methods to the worker class for use in collective_rpc + calls.""" + + world_size: int = field(init=False) + """world_size is TPxPP, it affects the number of workers we create.""" + + rank: int = 0 + """Global rank in distributed setup.""" + + enable_multimodal_encoder_data_parallel: bool = False + """ Use data parallelism instead of tensor parallelism for vision encoder. + Only support LLama4 for now""" + + @property + def world_size_across_dp(self) -> int: + """world_size_across_dp is TPxPPxDP, it is the size of the world + including data parallelism.""" + return self.world_size * self.data_parallel_size + + def get_next_dp_init_port(self) -> int: + """ + We might need to initialize process groups in multiple + processes that is related to data parallelism, + e.g. both in the worker and in the engine, which + can live in different processes. To avoid port conflicts, we + increment the port number each time we need to initialize a + new process group related to data parallelism. + """ + answer = self.data_parallel_master_port + self.data_parallel_master_port += 1 + return answer + + def stateless_init_dp_group(self) -> "ProcessGroup": + # NOTE: In high-concurrency scenarios multiple processes + # can pick the same (currently free) port through a race + # condition when calling `get_open_port()`. When the first + # process binds the port the others will subsequently fail + # with `torch.distributed.DistNetworkError: EADDRINUSE`. + # To make the initialization more robust we retry a few times + # with a fresh port whenever this specific error is observed. + from torch.distributed import DistNetworkError + + from vllm.distributed.utils import ( + stateless_init_torch_distributed_process_group) + + max_retries = 5 + last_exc: Optional[Exception] = None + for _ in range(max_retries): + try: + # use gloo since the engine process might not have cuda device + return stateless_init_torch_distributed_process_group( + self.data_parallel_master_ip, + self.get_next_dp_init_port(), + self.data_parallel_rank, + self.data_parallel_size, + backend="gloo") + except DistNetworkError as e: + # We only want to retry when the root cause is EADDRINUSE. + if "EADDRINUSE" in str(e): + logger.warning( + "Address already in use. Retrying with a new port.") + last_exc = e + continue # try again with a new port + raise e + + # If we get here all retries have failed. + assert last_exc is not None + raise last_exc + + @staticmethod + def has_unfinished_dp(dp_group: "ProcessGroup", + has_unfinished: bool) -> bool: + tensor = torch.tensor([has_unfinished], + dtype=torch.int32, + device="cpu") + # dp rank 0: has_unfinished_seqs=True + # dp rank 1: has_unfinished_seqs=False + # aggregated: has_unfinished_seqs=True + # so this is an OR operation, i.e. MAX in integers + torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group) + aggregated_has_unfinished = bool(tensor.item()) + return aggregated_has_unfinished + + def compute_hash(self): + """ + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + factors.append(self.pipeline_parallel_size) + factors.append(self.tensor_parallel_size) + factors.append(self.enable_expert_parallel) + factors.append(self.data_parallel_size) + factors.append(envs.VLLM_ALL2ALL_BACKEND) + return hashlib.sha256(str(factors).encode()).hexdigest() + + def __post_init__(self) -> None: + self.world_size = self.pipeline_parallel_size * \ + self.tensor_parallel_size + + if self.data_parallel_size_local > self.data_parallel_size: + raise ValueError( + f"data_parallel_size_local ({self.data_parallel_size_local}) " + f"must be <= data_parallel_size ({self.data_parallel_size})") + + if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: + # Data parallel was specified in the engine args. + self.data_parallel_master_port = get_open_port() + + if not (0 <= self.data_parallel_rank < self.data_parallel_size): + raise ValueError( + f"data_parallel_rank ({self.data_parallel_rank})" + f" must be in the range [0, {self.data_parallel_size})") + else: + # Otherwise fall back to env vars (e.g. for offline SPMD case). + self.data_parallel_size = envs.VLLM_DP_SIZE + self.data_parallel_rank = envs.VLLM_DP_RANK + self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL + self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP + self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT + + if self.data_parallel_external_lb: + raise ValueError("data_parallel_external_lb can only " + "be set when data_parallel_size > 1") + + if self.distributed_executor_backend == "external_launcher": + import os + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + logger.info("Disabling V1 multiprocessing for external launcher.") + + if self.enable_eplb: + if not current_platform.is_cuda(): + raise ValueError( + "Expert parallelism load balancing is only supported on " + "CUDA devices now.") + if self.num_redundant_experts < 0: + raise ValueError( + "num_redundant_experts must be non-negative, but got " + f"{self.num_redundant_experts}.") + else: + if self.num_redundant_experts != 0: + raise ValueError( + "num_redundant_experts should be used with EPLB." + f"{self.num_redundant_experts}.") + if self.distributed_executor_backend is None and self.world_size > 1: + # We use multiprocessing by default if world_size fits on the + # current node and we aren't in a ray placement group. + + from vllm.executor import ray_utils + backend: DistributedExecutorBackend = "mp" + ray_found = ray_utils.ray_is_available() + if current_platform.is_neuron(): + # neuron uses single process to control multiple devices + backend = "uni" + elif current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD: + backend = "uni" + elif (current_platform.is_cuda() + and cuda_device_count_stateless() < self.world_size): + if not ray_found: + raise ValueError("Unable to load Ray which is " + "required for multi-node inference, " + "please install Ray with `pip install " + "ray`.") from ray_utils.ray_import_err + backend = "ray" + elif self.data_parallel_backend == "ray": + logger.info("Using ray distributed inference because " + "data_parallel_backend is ray") + backend = "ray" + elif ray_found: + if self.placement_group: + backend = "ray" + else: + from ray import is_initialized as ray_is_initialized + if ray_is_initialized(): + from ray.util import get_current_placement_group + if get_current_placement_group(): + backend = "ray" + self.distributed_executor_backend = backend + logger.debug("Defaulting to use %s for distributed inference", + backend) + + if self.distributed_executor_backend is None and self.world_size == 1: + self.distributed_executor_backend = "uni" + + @property + def use_ray(self) -> bool: + return self.distributed_executor_backend == "ray" or ( + isinstance(self.distributed_executor_backend, type) + and self.distributed_executor_backend.uses_ray) + + @model_validator(mode='after') + def _verify_args(self) -> Self: + # Lazy import to avoid circular import + from vllm.executor.executor_base import ExecutorBase + from vllm.platforms import current_platform + if self.distributed_executor_backend not in ( + "ray", "mp", "uni", + "external_launcher", None) and not (isinstance( + self.distributed_executor_backend, type) and issubclass( + self.distributed_executor_backend, ExecutorBase)): + raise ValueError( + "Unrecognized distributed executor backend " + f"{self.distributed_executor_backend}. Supported " + "values are 'ray', 'mp' 'uni', 'external_launcher' or" + " custom ExecutorBase subclass.") + if self.use_ray: + from vllm.executor import ray_utils + ray_utils.assert_ray_available() + + # if not current_platform.use_custom_allreduce(): + # self.disable_custom_all_reduce = True + # logger.debug( + # "Disabled the custom all-reduce kernel because it is not " + # "supported on current platform.") + if self.ray_workers_use_nsight and not self.use_ray: + raise ValueError("Unable to use nsight profiling unless workers " + "run with Ray.") + + return self + + +PreemptionMode = Literal["swap", "recompute"] +SchedulerPolicy = Literal["fcfs", "priority"] + + +@config +@dataclass +class SchedulerConfig: + """Scheduler configuration.""" + + runner_type: RunnerType = "generate" + """The runner type to launch for the model.""" + + max_num_batched_tokens: SkipValidation[int] = None # type: ignore + """Maximum number of tokens to be processed in a single iteration. + + This config has no static default. If left unspecified by the user, it will + be set in `EngineArgs.create_engine_config` based on the usage context.""" + + max_num_seqs: SkipValidation[int] = None # type: ignore + """Maximum number of sequences to be processed in a single iteration. + + This config has no static default. If left unspecified by the user, it will + be set in `EngineArgs.create_engine_config` based on the usage context.""" + + max_model_len: SkipValidation[int] = None # type: ignore + """Maximum length of a sequence (including prompt and generated text). This + is primarily set in `ModelConfig` and that value should be manually + duplicated here.""" + + max_num_partial_prefills: int = 1 + """For chunked prefill, the maximum number of sequences that can be + partially prefilled concurrently.""" + + max_long_partial_prefills: int = 1 + """For chunked prefill, the maximum number of prompts longer than + long_prefill_token_threshold that will be prefilled concurrently. Setting + this less than max_num_partial_prefills will allow shorter prompts to jump + the queue in front of longer prompts in some cases, improving latency.""" + + long_prefill_token_threshold: int = 0 + """For chunked prefill, a request is considered long if the prompt is + longer than this number of tokens.""" + + num_lookahead_slots: int = 0 + """The number of slots to allocate per sequence per + step, beyond the known token ids. This is used in speculative + decoding to store KV activations of tokens which may or may not be + accepted. + + NOTE: This will be replaced by speculative config in the future; it is + present to enable correctness tests until then.""" + + cuda_graph_sizes: list[int] = field(default_factory=lambda: [512]) + """Cuda graph capture sizes, default is 512. + 1. if one value is provided, then the capture list would follow the + pattern: [1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)] + 2. more than one value (e.g. 1 2 128) is provided, then the capture list + will follow the provided list.""" + + delay_factor: float = 0.0 + """Apply a delay (of delay factor multiplied by previous + prompt latency) before scheduling next prompt.""" + + enable_chunked_prefill: SkipValidation[bool] = None # type: ignore + """If True, prefill requests can be chunked based + on the remaining max_num_batched_tokens.""" + + is_multimodal_model: bool = False + """True if the model is multimodal.""" + + # TODO (ywang96): Make this configurable. + max_num_encoder_input_tokens: int = field(init=False) + """Multimodal encoder compute budget, only used in V1. + + NOTE: This is not currently configurable. It will be overridden by + max_num_batched_tokens in case max multimodal embedding size is larger.""" + + # TODO (ywang96): Make this configurable. + encoder_cache_size: int = field(init=False) + """Multimodal encoder cache size, only used in V1. + + NOTE: This is not currently configurable. It will be overridden by + max_num_batched_tokens in case max multimodal embedding size is larger.""" + + preemption_mode: Optional[PreemptionMode] = None + """Whether to perform preemption by swapping or + recomputation. If not specified, we determine the mode as follows: + We use recomputation by default since it incurs lower overhead than + swapping. However, when the sequence group has multiple sequences + (e.g., beam search), recomputation is not currently supported. In + such a case, we use swapping instead.""" + + num_scheduler_steps: int = 1 + """Maximum number of forward steps per scheduler call.""" + + multi_step_stream_outputs: bool = True + """If False, then multi-step will stream outputs at the end of all steps""" + + send_delta_data: bool = False + """Private API. If used, scheduler sends delta data to + workers instead of an entire data. It should be enabled only + when SPMD worker architecture is enabled. I.e., + VLLM_USE_RAY_SPMD_WORKER=1""" + + policy: SchedulerPolicy = "fcfs" + """The scheduling policy to use:\n + - "fcfs" means first come first served, i.e. requests are handled in order + of arrival.\n + - "priority" means requests are handled based on given priority (lower + value means earlier handling) and time of arrival deciding any ties).""" + + chunked_prefill_enabled: bool = field(init=False) + """True if chunked prefill is enabled.""" + + disable_chunked_mm_input: bool = False + """If set to true and chunked prefill is enabled, we do not want to + partially schedule a multimodal item. Only used in V1 + This ensures that if a request has a mixed prompt + (like text tokens TTTT followed by image tokens IIIIIIIIII) where only + some image tokens can be scheduled (like TTTTIIIII, leaving IIIII), + it will be scheduled as TTTT in one step and IIIIIIIIII in the next.""" + + # scheduler class or path. "vllm.core.scheduler.Scheduler" (default) + # or "mod.custom_class". + scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler" + """The scheduler class to use. "vllm.core.scheduler.Scheduler" is the + default scheduler. Can be a class directly or the path to a class of form + "mod.custom_class".""" + + disable_hybrid_kv_cache_manager: bool = False + """If set to True, KV cache manager will allocate the same size of KV cache + for all attention layers even if there are multiple type of attention layers + like full attention and sliding window attention. + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self) -> None: + if self.max_model_len is None: + self.max_model_len = 8192 + + if self.max_num_seqs is None: + self.max_num_seqs = 128 + + if self.max_num_batched_tokens is None: + if self.enable_chunked_prefill: + if self.num_scheduler_steps > 1: + # Multi-step Chunked-Prefill doesn't allow prompt-chunking + # for now. Have max_num_batched_tokens set to max_model_len + # so we don't reject sequences on account of a short + # max_num_batched_tokens. + self.max_num_batched_tokens = max( + self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS) + else: + self.max_num_batched_tokens = ( + DEFAULT_MAX_NUM_BATCHED_TOKENS) + else: + # If max_model_len is too short, use + # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value + # for higher throughput. + self.max_num_batched_tokens = max( + self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS) + + if self.runner_type == "pooling": + # Choose specific value for higher throughput + self.max_num_batched_tokens = max( + self.max_num_batched_tokens, + POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + if self.is_multimodal_model: + # The value needs to be at least the number of multimodal tokens + self.max_num_batched_tokens = max( + self.max_num_batched_tokens, + MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + + # When using default settings, + # Ensure max_num_batched_tokens does not exceed model limit. + # Some models (e.g., Whisper) have embeddings tied to max length. + self.max_num_batched_tokens = min( + self.max_num_seqs * self.max_model_len, + self.max_num_batched_tokens) + + self.max_num_encoder_input_tokens = self.max_num_batched_tokens + self.encoder_cache_size = self.max_num_batched_tokens + + if self.enable_chunked_prefill: + logger.info( + "Chunked prefill is enabled with max_num_batched_tokens=%d.", + self.max_num_batched_tokens) + + self.chunked_prefill_enabled = self.enable_chunked_prefill + if self.max_num_partial_prefills > 1: + if self.long_prefill_token_threshold == 0: + self.long_prefill_token_threshold = int(self.max_model_len * + 0.04) + + logger.info( + "Concurrent partial prefills enabled with " + "max_num_partial_prefills=%d, max_long_partial_prefills=%d, " + "long_prefill_token_threshold=%d", + self.max_num_partial_prefills, self.max_long_partial_prefills, + self.long_prefill_token_threshold) + + @model_validator(mode='after') + def _verify_args(self) -> Self: + if (self.max_num_batched_tokens < self.max_model_len + and not self.chunked_prefill_enabled): + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " + f"smaller than max_model_len ({self.max_model_len}). " + "This effectively limits the maximum sequence length to " + "max_num_batched_tokens and makes vLLM reject longer " + "sequences. Please increase max_num_batched_tokens or " + "decrease max_model_len.") + + if self.max_num_batched_tokens < self.max_num_seqs: + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " + "be greater than or equal to max_num_seqs " + f"({self.max_num_seqs}).") + + if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len: + logger.warning( + "max_num_batched_tokens (%d) exceeds max_num_seqs " + "* max_model_len (%d). This may lead to unexpected behavior.", + self.max_num_batched_tokens, + self.max_num_seqs * self.max_model_len) + + if self.num_lookahead_slots < 0: + raise ValueError( + "num_lookahead_slots " + f"({self.num_lookahead_slots}) must be greater than or " + "equal to 0.") + + if self.num_scheduler_steps < 1: + raise ValueError( + "num_scheduler_steps " + f"({self.num_scheduler_steps}) must be greater than or " + "equal to 1.") + + if self.max_num_partial_prefills < 1: + raise ValueError( + f"max_num_partial_prefills ({self.max_num_partial_prefills}) " + "must be greater than or equal to 1.") + elif self.max_num_partial_prefills > 1: + if not self.chunked_prefill_enabled: + raise ValueError("Chunked prefill must be enabled to set " + "max_num_partial_prefills > 1.") + + if self.long_prefill_token_threshold > self.max_model_len: + raise ValueError( + "long_prefill_token_threshold " + f"({self.long_prefill_token_threshold}) cannot be greater " + f"than the max_model_len ({self.max_model_len}).") + + if (self.max_long_partial_prefills + < 1) or (self.max_long_partial_prefills + > self.max_num_partial_prefills): + raise ValueError( + f"max_long_partial_prefills ({self.max_long_partial_prefills}) " + "must be greater than or equal to 1 and less than or equal to " + f"max_num_partial_prefills ({self.max_num_partial_prefills}).") + + return self + + @property + def is_multi_step(self) -> bool: + return self.num_scheduler_steps > 1 + + +Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"] + + +@config +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class DeviceConfig: + """Configuration for the device to use for vLLM execution.""" + + device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto" + """Device type for vLLM execution. + This parameter is deprecated and will be + removed in a future release. + It will now be set automatically based + on the current platform.""" + device_type: str = field(init=False) + """Device type from the current platform. This is set in + `__post_init__`.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # the device/platform information will be summarized + # by torch/vllm automatically. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + if self.device == "auto": + # Automated device type detection + from vllm.platforms import current_platform + self.device_type = current_platform.device_type + if not self.device_type: + raise RuntimeError( + "Failed to infer device type, please set " + "the environment variable `VLLM_LOGGING_LEVEL=DEBUG` " + "to turn on verbose logging to help debug the issue.") + else: + # Device type is assigned explicitly + if isinstance(self.device, str): + self.device_type = self.device + elif isinstance(self.device, torch.device): + self.device_type = self.device.type + + # Some device types require processing inputs on CPU + if self.device_type in ["neuron"]: + self.device = torch.device("cpu") + elif self.device_type in ["tpu"]: + self.device = None + else: + # Set device with device type + self.device = torch.device(self.device_type) + + +SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa", + "mlp_speculator", "draft_model", "deepseek_mtp"] +SpeculativeAcceptanceMethod = Literal["rejection_sampler", + "typical_acceptance_sampler"] + + +@config +@dataclass +class SpeculativeConfig: + """Configuration for speculative decoding.""" + + # General speculative decoding control + num_speculative_tokens: SkipValidation[int] = None # type: ignore + """The number of speculative tokens, if provided. It will default to the + number in the draft model config if present, otherwise, it is required.""" + model: Optional[str] = None + """The name of the draft model, eagle head, or additional weights, if + provided.""" + method: Optional[SpeculativeMethod] = None + """The name of the speculative method to use. If users provide and set the + `model` param, the speculative method type will be detected automatically + if possible, if `model` param is not provided, the method name must be + provided. + + If using `ngram` method, the related configuration `prompt_lookup_max` and + `prompt_lookup_min` should be considered.""" + acceptance_method: SpeculativeAcceptanceMethod = "rejection_sampler" + """The method to use for accepting draft tokens:\n + - "rejection_sampler" maps to `RejectionSampler`.\n + - "typical_acceptance_sampler" maps to `TypicalAcceptanceSampler`. + + If using `typical_acceptance_sampler`, the related configuration + `posterior_threshold` and `posterior_alpha` should be considered.""" + draft_tensor_parallel_size: Optional[int] = None + """The degree of the tensor parallelism for the draft model. Can only be 1 + or the same as the target model's tensor parallel size.""" + disable_logprobs: bool = True + """If set to True, token log probabilities are not returned during + speculative decoding. If set to False, token log probabilities are returned + according to the log probability settings in SamplingParams.""" + + # Draft model configuration + quantization: Optional[me_quant.QuantizationMethods] = None + """Quantization method that was used to quantize the draft model weights. + If `None`, we assume the model weights are not quantized. Note that it only + takes effect when using the draft model-based speculative method.""" + max_model_len: Optional[int] = None + """The maximum model length of the draft model. Used when testing the + ability to skip speculation for some sequences.""" + revision: Optional[str] = None + """The specific model version to use for the draft model. It can be a + branch name, a tag name, or a commit id. If unspecified, will use the + default version.""" + code_revision: Optional[str] = None + """The specific revision to use for the draft model code on Hugging Face + Hub. It can be a branch name, a tag name, or a commit id. If unspecified, + will use the default version.""" + + # Advanced control + disable_mqa_scorer: bool = False + """Disable the MQA scorer and fall back to batch expansion for scoring + proposals.""" + disable_by_batch_size: Optional[int] = None + """Disable speculative decoding for new incoming requests when the number + of enqueued requests is larger than this value, if provided.""" + + # Ngram proposer configuration + prompt_lookup_max: Optional[int] = None + """Maximum size of ngram token window when using Ngram proposer, required + when method is set to ngram.""" + prompt_lookup_min: Optional[int] = None + """Minimum size of ngram token window when using Ngram proposer, if + provided. Defaults to 1.""" + + # Typical acceptance sampler configuration + posterior_threshold: Optional[float] = None + """A threshold value that sets a lower bound on the posterior probability + of a token in the target model for it to be accepted. This threshold is + used only when we use the `TypicalAcceptanceSampler` for token acceptance. + """ + posterior_alpha: Optional[float] = None + """Scaling factor for entropy-based threshold, applied when using + `TypicalAcceptanceSampler`.""" + + speculative_token_tree: Optional[str] = None + """Specifies the tree structure for speculative token generation. + """ + # required configuration params passed from engine + target_model_config: SkipValidation[ModelConfig] = None # type: ignore + """The configuration of the target model.""" + target_parallel_config: SkipValidation[ + ParallelConfig] = None # type: ignore + """The parallel configuration for the target model.""" + enable_chunked_prefill: SkipValidation[bool] = None # type: ignore + """Whether vLLM is configured to use chunked prefill or not. Used for + raising an error since it's not yet compatible with speculative decode.""" + disable_log_stats: SkipValidation[bool] = None # type: ignore + """Whether to disable the periodic printing of stage times in speculative + decoding.""" + + # params generated in the post-init stage + draft_model_config: SkipValidation[ModelConfig] = None # type: ignore + """The configuration of the draft model initialized internal.""" + draft_parallel_config: SkipValidation[ + ParallelConfig] = None # type: ignore + """The parallel configuration for the draft model initialized internal.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + # Eagle3 affects the computation graph because it returns intermediate + # hidden states in addition to the final hidden state. + factors.append(self.method == "eagle3") + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + @classmethod + def from_dict(cls, dict_value: dict) -> "SpeculativeConfig": + """Parse the CLI value for the speculative config.""" + return cls(**dict_value) + + @staticmethod + def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: + if hf_config.model_type == "deepseek_v3": + hf_config.model_type = "deepseek_mtp" + if hf_config.model_type == "deepseek_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["DeepSeekMTPModel"] + }) + + if hf_config.architectures[0] == "MiMoForCausalLM": + hf_config.model_type = "mimo_mtp" + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "num_hidden_layers": 0, + "n_predict": n_predict, + "architectures": ["MiMoMTPModel"] + }) + + if hf_config.architectures[0] == "Glm4MoeForCausalLM": + hf_config.model_type = "glm4_moe_mtp" + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "num_hidden_layers": 0, + "n_predict": n_predict, + "architectures": ["Glm4MoeMTPModel"] + }) + + return hf_config + + def __post_init__(self): + + # Note: "method" is a new parameter that helps to extend the + # configuration of non-model-based proposers, and the "model" parameter + # will be used to set the draft model, eagle head, or additional weight + # when needed. If users do not specify "method", the speculative method + # will be detected automatically if possible. If the speculative method + # can not be detected, it will be considered as the "draft_model" by + # default. + + if self.model is None and self.num_speculative_tokens is not None: + # TODO(Shangming): Refactor mtp configuration logic when supporting + # mtp acceleration for more models besides deepseek_v3 + if self.target_model_config and \ + (self.target_model_config.hf_text_config.model_type \ + == "deepseek_v3" or + self.target_model_config.hf_text_config.model_type \ + == "mimo"): + # use the draft model from the same model: + self.model = self.target_model_config.model + elif self.method in ("ngram", "[ngram]"): + self.model = "ngram" + else: + raise ValueError("num_speculative_tokens was provided without " + "speculative model.") + + # Automatically configure the method for ngram when "model" is used + # instead of "method" + if self.method is None and (self.model is not None + and self.model in ("ngram", "[ngram]")): + self.method = "ngram" + + if self.method in ("ngram", "[ngram]"): + # Unified to "ngram" internally + self.method = "ngram" + # Set default values if not provided + if (self.prompt_lookup_min is None + and self.prompt_lookup_max is None): + # TODO(woosuk): Tune these values. They are arbitrarily chosen. + self.prompt_lookup_min = 5 + self.prompt_lookup_max = 5 + elif self.prompt_lookup_min is None: + assert self.prompt_lookup_max is not None + self.prompt_lookup_min = self.prompt_lookup_max + elif self.prompt_lookup_max is None: + assert self.prompt_lookup_min is not None + self.prompt_lookup_max = self.prompt_lookup_min + + # Validate values + if self.prompt_lookup_min < 1: + raise ValueError( + f"prompt_lookup_min={self.prompt_lookup_min} must be > 0") + if self.prompt_lookup_max < 1: + raise ValueError( + f"prompt_lookup_max={self.prompt_lookup_max} must be > 0") + if self.prompt_lookup_min > self.prompt_lookup_max: + raise ValueError( + f"prompt_lookup_min={self.prompt_lookup_min} must " + f"be <= prompt_lookup_max={self.prompt_lookup_max}") + + # TODO: current we still need extract vocab_size from target model + # config, in future, we may try refactor it out, and set + # draft related config as None here. + self.draft_model_config = self.target_model_config + self.draft_parallel_config = self.target_parallel_config + else: + self.prompt_lookup_max = 0 + self.prompt_lookup_min = 0 + + if self.model is not None: + self.draft_model_config = ModelConfig( + model=self.model, + task="draft", + tokenizer=self.target_model_config.tokenizer, + tokenizer_mode=self.target_model_config.tokenizer_mode, + trust_remote_code=self.target_model_config. + trust_remote_code, + allowed_local_media_path=self.target_model_config. + allowed_local_media_path, + dtype=self.target_model_config.dtype, + seed=self.target_model_config.seed, + revision=self.revision, + code_revision=self.code_revision, + tokenizer_revision=self.target_model_config. + tokenizer_revision, + spec_target_max_model_len=self.target_model_config. + max_model_len, + quantization=self.quantization, + enforce_eager=True if envs.VLLM_SPEC_DECODE_EAGER else self.target_model_config.enforce_eager, + max_seq_len_to_capture=self.target_model_config. + max_seq_len_to_capture, + max_logprobs=self.target_model_config.max_logprobs, + hf_overrides=SpeculativeConfig.hf_config_override, + ) + + # Automatically detect the method + if self.method in ('eagle', 'eagle3'): + pass + elif "eagle-" in self.draft_model_config.model.lower() or \ + "eagle3-" in self.draft_model_config.model.lower(): + self.method = "eagle" + elif self.draft_model_config.hf_config.model_type == "medusa": + self.method = "medusa" + elif (self.draft_model_config.hf_config.model_type == + "mlp_speculator"): + self.method = "mlp_speculator" + elif (self.draft_model_config.hf_config.model_type == + "deepseek_mtp", "glm4_moe_mtp"): + self.method = "deepseek_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Deepseek MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) + else: + self.method = "draft_model" + + # Replace hf_config for EAGLE draft_model + if self.method in ("eagle", "eagle3"): + if self.enable_chunked_prefill and not envs.VLLM_USE_V1: + raise ValueError( + "Chunked prefill and EAGLE are not compatible " + "when using V0.") + + from vllm.transformers_utils.configs.eagle import ( + EAGLEConfig) + if isinstance(self.draft_model_config.hf_config, + EAGLEConfig): + pass + else: + eagle_config = EAGLEConfig( + self.draft_model_config.hf_config, + method=self.method, + model_type="eagle") + self.draft_model_config.hf_config = eagle_config + + if (self.num_speculative_tokens is not None + and hasattr(self.draft_model_config.hf_config, + "num_lookahead_tokens")): + self.draft_model_config.hf_config.num_lookahead_tokens = \ + self.num_speculative_tokens + + # if (self.num_speculative_heads is not None + # and hasattr(self.draft_model_config.hf_config, "num_lookahead_heads")): + # self.draft_model_config.hf_config.num_lookahead_heads = self.num_speculative_heads + + n_predict = getattr(self.draft_model_config.hf_config, + "n_predict", None) + if n_predict is not None: + if self.num_speculative_tokens is None: + # Default to max value defined in draft model config. + self.num_speculative_tokens = n_predict + elif self.num_speculative_tokens > n_predict and \ + self.num_speculative_tokens % n_predict != 0: + # Ensure divisibility for MTP module reuse. + raise ValueError( + f"num_speculative_tokens:{self.num_speculative_tokens}" + f" must be divisible by {n_predict=}") + + self.draft_tensor_parallel_size = \ + SpeculativeConfig._verify_and_get_draft_tp( + self.target_parallel_config, + self.draft_tensor_parallel_size, + self.draft_model_config.hf_config + ) + + self.draft_model_config.max_model_len = ( + SpeculativeConfig._maybe_override_draft_max_model_len( + self.max_model_len, + self.draft_model_config.max_model_len, + self.target_model_config.max_model_len, + )) + + self.draft_parallel_config = ( + SpeculativeConfig.create_draft_parallel_config( + self.target_parallel_config, + self.draft_tensor_parallel_size)) + + if self.acceptance_method == "typical_acceptance_sampler": + if self.posterior_threshold is None: + self.posterior_threshold = 0.09 + if self.posterior_alpha is None: + self.posterior_alpha = 0.3 + + @staticmethod + def _maybe_override_draft_max_model_len( + speculative_max_model_len: Optional[int], + draft_max_model_len: int, + target_max_model_len: int, + ) -> int: + """Determine the max sequence len for the draft model. This is usually + the draft_max_model_len, but may be the target_max_model_len if it is + less than the draft_max_model_len, or may be speculative_max_model_len + if it is specified. + + This is necessary so that sequences do not exceed the capacity of the + draft model or the target model. + + speculative_max_model_len is mainly used for testing that sequences can + skip speculation. + """ + + if speculative_max_model_len is not None: + + if speculative_max_model_len > draft_max_model_len: + raise ValueError(f"{speculative_max_model_len=} cannot be " + f"larger than {draft_max_model_len=}") + + if speculative_max_model_len > target_max_model_len: + raise ValueError(f"{speculative_max_model_len=} cannot be " + f"larger than {target_max_model_len=}") + + return speculative_max_model_len + + return min( + draft_max_model_len, + target_max_model_len, + ) + + @staticmethod + def _verify_and_get_draft_tp( + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: Optional[int], + draft_hf_config: PretrainedConfig) -> int: + """ + Verifies and adjusts the tensor parallel size for a draft model + specified using speculative_draft_tensor_parallel_size. + """ + # If speculative_draft_tensor_parallel_size is unset then set it + # appropriately else verify that it is set correctly. + if speculative_draft_tensor_parallel_size is None: + if draft_hf_config.model_type == "mlp_speculator": + speculative_draft_tensor_parallel_size = 1 + if target_parallel_config.tensor_parallel_size > 1: + logger.warning( + "%s cannot currently be run with tp>1; " + "setting speculative_draft_tensor_parallel_size=1", + draft_hf_config.model_type) + else: + speculative_draft_tensor_parallel_size = \ + target_parallel_config.tensor_parallel_size + elif speculative_draft_tensor_parallel_size not in ( + 1, target_parallel_config.tensor_parallel_size): + raise ValueError( + f"{speculative_draft_tensor_parallel_size=} cannot be " + f"other value than 1 or target model tensor_parallel_size") + return speculative_draft_tensor_parallel_size + + @staticmethod + def create_draft_parallel_config( + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: int, + ) -> ParallelConfig: + """Create a parallel config for use by the draft worker. + + This is mostly a copy of the target parallel config, except the tp_size. + """ + draft_parallel_config = ParallelConfig( + pipeline_parallel_size=target_parallel_config. + pipeline_parallel_size, + tensor_parallel_size=speculative_draft_tensor_parallel_size, + distributed_executor_backend=target_parallel_config. + distributed_executor_backend, + max_parallel_loading_workers=target_parallel_config. + max_parallel_loading_workers, + disable_custom_all_reduce=target_parallel_config. + disable_custom_all_reduce, + ray_workers_use_nsight=target_parallel_config. + ray_workers_use_nsight, + placement_group=target_parallel_config.placement_group, + ) + + return draft_parallel_config + + @model_validator(mode='after') + def _verify_args(self) -> Self: + if self.num_speculative_tokens is None: + raise ValueError( + "num_speculative_tokens must be provided with " + "speculative model unless the draft model config contains an " + "n_predict parameter.") + + if self.num_speculative_tokens <= 0: + raise ValueError("Expected num_speculative_tokens to be greater " + f"than zero ({self.num_speculative_tokens}).") + + if self.draft_model_config: + self.draft_model_config.verify_with_parallel_config( + self.draft_parallel_config) + # Validate and set draft token acceptance related settings. + + if self.acceptance_method is None: + raise ValueError("acceptance_method is not set. " + "Expected values are rejection_sampler or " + "typical_acceptance_sampler.") + + if (self.acceptance_method != 'rejection_sampler' + and self.acceptance_method != 'typical_acceptance_sampler'): + raise ValueError( + "Expected acceptance_method to be either " + "rejection_sampler or typical_acceptance_sampler. Instead it " + f"is {self.acceptance_method}") + + if self.acceptance_method == "typical_acceptance_sampler" and ( + (self.posterior_threshold is not None + and self.posterior_threshold < 0) or + (self.posterior_alpha is not None and self.posterior_alpha < 0)): + raise ValueError( + "Expected the posterior_threshold and posterior_alpha of " + "typical_acceptance_sampler to be > 0. " + "Instead found posterior_threshold = " + f"{self.posterior_threshold} and posterior_alpha = " + f"{self.posterior_alpha}") + + if (self.disable_by_batch_size is not None + and self.disable_by_batch_size < 2): + raise ValueError("Expect the batch size threshold of disabling " + "speculative decoding is > 1, but got " + f"{self.disable_by_batch_size=}") + + if self.method == "eagle3" and self.target_model_config and \ + "llama" not in self.target_model_config.hf_text_config.model_type: + raise ValueError( + "Eagle3 is only supported for Llama models. " + f"Got {self.target_model_config.hf_text_config.model_type=}") + + return self + + @property + def num_lookahead_slots(self) -> int: + """The number of additional slots the scheduler should allocate per + step, in addition to the slots allocated for each known token. + + This is equal to the number of speculative tokens, as each speculative + token must be scored. + """ + return self.num_speculative_tokens + + def use_eagle(self) -> bool: + return self.method in ("eagle", "eagle3", "deepseek_mtp") + + def __repr__(self) -> str: + method = self.method + model = None if method == "ngram" else self.draft_model_config.model + num_spec_tokens = self.num_speculative_tokens + return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})" + + +LoRADType = Literal["auto", "float16", "bfloat16"] + + +@config +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class LoRAConfig: + """Configuration for LoRA.""" + + max_lora_rank: int = 16 + """Max LoRA rank.""" + max_loras: int = 1 + """Max number of LoRAs in a single batch.""" + fully_sharded_loras: bool = False + """By default, only half of the LoRA computation is sharded with tensor + parallelism. Enabling this will use the fully sharded layers. At high + sequence length, max rank or tensor parallel size, this is likely faster. + """ + max_cpu_loras: Optional[int] = None + """Maximum number of LoRAs to store in CPU memory. Must be >= than + `max_loras`.""" + lora_target_modules: Optional[List[str]] = None + """List of lora module name, If not specified, + modules will be chosen according to the model architecture. + """ + lora_dtype: Union[torch.dtype, LoRADType] = "auto" + """Data type for LoRA. If auto, will default to base model dtype.""" + lora_extra_vocab_size: int = 256 + """Maximum size of extra vocabulary that can be present in a LoRA adapter + (added to the base model vocabulary).""" + lora_vocab_padding_size: ClassVar[int] = current_platform\ + .get_lora_vocab_padding_size() + long_lora_scaling_factors: Optional[tuple[float, ...]] = None + """Specify multiple scaling factors (which can be different from base model + scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters + trained with those scaling factors to be used at the same time. If not + specified, only adapters trained with the base model scaling factor are + allowed.""" + bias_enabled: bool = False + """Enable bias for LoRA adapters.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + factors.append(self.max_lora_rank) + factors.append(self.max_loras) + factors.append(self.fully_sharded_loras) + factors.append(self.lora_dtype) + factors.append(self.lora_extra_vocab_size) + factors.append(self.lora_vocab_padding_size) + factors.append(self.long_lora_scaling_factors) + factors.append(self.bias_enabled) + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + # Setting the maximum rank to 512 should be able to satisfy the vast + # majority of applications. + possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512) + possible_lora_extra_vocab_size = (256, 512) + if self.max_lora_rank not in possible_max_ranks: + raise ValueError( + f"max_lora_rank ({self.max_lora_rank}) must be one of " + f"{possible_max_ranks}.") + if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size: + raise ValueError( + f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) " + f"must be one of {possible_lora_extra_vocab_size}.") + if self.max_loras < 1: + raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") + if self.max_cpu_loras is None: + self.max_cpu_loras = self.max_loras + elif self.max_cpu_loras < self.max_loras: + raise ValueError( + f"max_cpu_loras ({self.max_cpu_loras}) must be >= " + f"max_loras ({self.max_loras})") + + def verify_with_cache_config(self, cache_config: CacheConfig): + if cache_config.cpu_offload_gb > 0 and not envs.VLLM_USE_V1: + raise ValueError( + "V0 LoRA does not support CPU offload, please use V1.") + + def verify_with_model_config(self, model_config: ModelConfig): + if self.lora_dtype in (None, "auto"): + self.lora_dtype = model_config.dtype + elif isinstance(self.lora_dtype, str): + self.lora_dtype = getattr(torch, self.lora_dtype) + + def verify_lora_support(self): + if self.long_lora_scaling_factors is not None and envs.VLLM_USE_V1: + raise ValueError( + "V1 LoRA does not support long LoRA, please use V0.") + + +@config +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class PromptAdapterConfig: + """Configuration for PromptAdapters.""" + + max_prompt_adapters: int = 1 + """Max number of PromptAdapters in a batch.""" + max_prompt_adapter_token: int = 0 + """Max number of PromptAdapters tokens.""" + max_cpu_prompt_adapters: Optional[int] = None + """Maximum number of PromptAdapters to store in CPU memory. Must be >= than + `max_prompt_adapters`.""" + prompt_adapter_dtype: Union[torch.dtype, str] = "auto" + """Data type for PromptAdapter. If auto, will default to base model dtype. + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + + if self.max_prompt_adapters < 1: + raise ValueError(f"max_prompt_adapters " + f"({self.max_prompt_adapters}) must be >= 1.") + if self.max_prompt_adapter_token == 0: + raise ValueError("max_prompt_adapter_token must be set.") + if self.max_cpu_prompt_adapters is None: + self.max_cpu_prompt_adapters = self.max_prompt_adapters + + def verify_with_model_config(self, model_config: ModelConfig): + if self.prompt_adapter_dtype == "auto": + self.prompt_adapter_dtype = model_config.dtype + elif isinstance(self.prompt_adapter_dtype, str): + self.prompt_adapter_dtype = getattr(torch, + self.prompt_adapter_dtype) + + +@config +@dataclass +class MultiModalConfig: + """Controls the behavior of multimodal models.""" + + limit_per_prompt: dict[str, int] = \ + cast(dict[str, int], get_field(ModelConfig, "limit_mm_per_prompt")) + """ + The maximum number of input items allowed per prompt for each modality. + Defaults to 1 (V0) or 999 (V1) for each modality. + + For example, to allow up to 16 images and 2 videos per prompt: + `{"images": 16, "videos": 2}` + """ + + media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) + """Additional args passed to process media inputs, keyed by modalities. + For example, to set num_frames for video, set + `--media-io-kwargs '{"video": {"num_frames": 40} }'` """ + + mm_processor_kwargs: Optional[dict[str, object]] = None + """ + Overrides for the multi-modal processor obtained from + `transformers.AutoProcessor.from_pretrained`. + + The available overrides depend on the model that is being run. + + For example, for Phi-3-Vision: + `{"num_crops": 4}`. + """ + + disable_mm_preprocessor_cache: bool = False + """ + If `True`, disable caching of the processed multi-modal inputs. + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def get_limit_per_prompt(self, modality: str) -> int: + """ + Get the maximum number of input items allowed per prompt + for the given modality. + """ + return self.limit_per_prompt.get( + modality, + 999 if envs.VLLM_USE_V1 else 1, + ) + + # TODO: Add configs to init vision tower or not. + + +@config +@dataclass +class PoolerConfig: + """Controls the behavior of output pooling in pooling models.""" + + pooling_type: Optional[str] = None + """ + The pooling method of the pooling model. This should be a key in + [`vllm.model_executor.layers.pooler.PoolingType`][]. + """ + + normalize: Optional[bool] = None + """ + Whether to normalize the pooled outputs. Usually, this should be set to + ``True`` for embedding outputs. + """ + + softmax: Optional[bool] = None + """ + Whether to apply softmax to the pooled outputs. Usually, this should be set + to ``True`` for classification outputs. + """ + + step_tag_id: Optional[int] = None + """ + If set, only the score corresponding to the ``step_tag_id`` in the + generated sentence should be returned. Otherwise, the scores for all tokens + are returned. + """ + + returned_token_ids: Optional[list[int]] = None + """ + A list of indices for the vocabulary dimensions to be extracted, + such as the token IDs of ``good_token`` and ``bad_token`` in the + ``math-shepherd-mistral-7b-prm`` model. + """ + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + +_STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.float16, + "float16": torch.float16, + "float": torch.float32, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + +# model_type -> reason +_FLOAT16_NOT_SUPPORTED_MODELS = { + "gemma2": "Numerical instability. Please use bfloat16 or float32 instead.", + "gemma3": "Numerical instability. Please use bfloat16 or float32 instead.", + "plamo2": "Numerical instability. Please use bfloat16 or float32 instead.", + "glm4": "Numerical instability. Please use bfloat16 or float32 instead.", +} + + +def _is_valid_dtype(model_type: str, dtype: torch.dtype): + if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: # noqa: E501, SIM103 + return False + + return True + + +def _check_valid_dtype(model_type: str, dtype: torch.dtype): + if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: + reason = _FLOAT16_NOT_SUPPORTED_MODELS[model_type] + raise ValueError(f"The model type {model_type!r} " + f"does not support float16. Reason: {reason}") + + return True + + +def _find_dtype( + model_id: str, + config: PretrainedConfig, + *, + revision: Optional[str], +): + # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct + # because config.torch_dtype can be None. + config_dtype = getattr(config, "torch_dtype", None) + + # Fallbacks for multi-modal models if the root config + # does not define torch_dtype + if config_dtype is None: + config_dtype = getattr(config.get_text_config(), "torch_dtype", None) + if config_dtype is None and hasattr(config, "vision_config"): + config_dtype = getattr(config.vision_config, "torch_dtype", None) + if config_dtype is None and hasattr(config, "encoder_config"): + config_dtype = getattr(config.encoder_config, "torch_dtype", None) + + # Try to read the dtype of the weights if they are in safetensors format + if config_dtype is None: + repo_mt = try_get_safetensors_metadata(model_id, revision=revision) + + if repo_mt and (files_mt := repo_mt.files_metadata): + param_dtypes: set[torch.dtype] = { + _SAFETENSORS_TO_TORCH_DTYPE[dtype_str] + for file_mt in files_mt.values() + for dtype_str in file_mt.parameter_count + if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE + } + + if param_dtypes: + return common_broadcastable_dtype(param_dtypes) + + if config_dtype is None: + config_dtype = torch.float32 + + return config_dtype + + +def _resolve_auto_dtype( + model_type: str, + config_dtype: torch.dtype, + *, + is_pooling_model: bool, +): + from vllm.platforms import current_platform + + supported_dtypes = [ + dtype for dtype in current_platform.supported_dtypes + if _is_valid_dtype(model_type, dtype) + ] + + if is_pooling_model and torch.float16 in supported_dtypes: + preferred_dtype = torch.float16 + else: + preferred_dtype = supported_dtypes[0] + + # Downcast for float32 models + if config_dtype == torch.float32: + config_dtype = preferred_dtype + + if config_dtype in supported_dtypes: + return config_dtype + + # Ensure device compatibility + device_name = current_platform.get_device_name() + device_capability = current_platform.get_device_capability() + + if device_capability is None: + device_str = f"{device_name!r}" + else: + version_str = device_capability.as_version_str() + device_str = f"{device_name!r} (with compute capability {version_str})" + + logger.warning( + "Your device %s doesn't support %s. " + "Falling back to %s for compatibility.", + device_str, + config_dtype, + preferred_dtype, + ) + + return preferred_dtype + + +def _get_and_verify_dtype( + model_id: str, + config: PretrainedConfig, + dtype: Union[str, torch.dtype], + *, + is_pooling_model: bool, + revision: Optional[str] = None, +) -> torch.dtype: + config_dtype = _find_dtype(model_id, config, revision=revision) + model_type = config.model_type + + if isinstance(dtype, str): + dtype = dtype.lower() + if dtype == "auto": + # Set default dtype from model config + torch_dtype = _resolve_auto_dtype( + model_type, + config_dtype, + is_pooling_model=is_pooling_model, + ) + else: + if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: + raise ValueError(f"Unknown dtype: {dtype!r}") + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + elif isinstance(dtype, torch.dtype): + torch_dtype = dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + _check_valid_dtype(model_type, torch_dtype) + + if torch_dtype != config_dtype: + if torch_dtype == torch.float32: + # Upcasting to float32 is allowed. + logger.info("Upcasting %s to %s.", config_dtype, torch_dtype) + elif config_dtype == torch.float32: + # Downcasting from float32 to float16 or bfloat16 is allowed. + logger.info("Downcasting %s to %s.", config_dtype, torch_dtype) + else: + # Casting between float16 and bfloat16 is allowed with a warning. + logger.warning("Casting %s to %s.", config_dtype, torch_dtype) + + return torch_dtype + + +def _get_and_verify_max_len( + hf_config: PretrainedConfig, + tokenizer_config: Optional[dict], + max_model_len: Optional[int], + disable_sliding_window: bool, + sliding_window_len: Optional[Union[int, list[Optional[int]]]], + spec_target_max_model_len: Optional[int] = None, + encoder_config: Optional[Any] = None, +) -> int: + """Get and verify the model's maximum length.""" + derived_max_model_len = float("inf") + possible_keys = [ + # OPT + "max_position_embeddings", + # GPT-2 + "n_positions", + # MPT + "max_seq_len", + # ChatGLM2 + "seq_length", + # Command-R + "model_max_length", + # Whisper + "max_target_positions", + # Others + "max_sequence_length", + "max_seq_length", + "seq_len", + ] + # Choose the smallest "max_length" from the possible keys + max_len_key = None + for key in possible_keys: + max_len = getattr(hf_config, key, None) + if max_len is not None: + max_len_key = key if max_len < derived_max_model_len \ + else max_len_key + derived_max_model_len = min(derived_max_model_len, max_len) + # For Command-R / Cohere, Cohere2 / Aya Vision models + if tmp_max_len := getattr(hf_config, "model_max_length", None): + max_len_key = "model_max_length" + derived_max_model_len = tmp_max_len + + # If sliding window is manually disabled, max_length should be less + # than the sliding window length in the model config. + if disable_sliding_window and sliding_window_len is not None: + + sliding_window_len_min = get_min_sliding_window(sliding_window_len) + max_len_key = "sliding_window" \ + if sliding_window_len_min < derived_max_model_len else max_len_key + derived_max_model_len = min(derived_max_model_len, + sliding_window_len_min) + + # Consider model_max_length in tokenizer_config + if tokenizer_config: + tokenizer_model_max_length = tokenizer_config.get( + "model_max_length", derived_max_model_len) + derived_max_model_len = min(derived_max_model_len, + tokenizer_model_max_length) + + # If none of the keys were found in the config, use a default and + # log a warning. + if derived_max_model_len == float("inf"): + if max_model_len is not None: + # If max_model_len is specified, we use it. + return max_model_len + + if spec_target_max_model_len is not None: + # If this is a speculative draft model, we use the max model len + # from the target model. + return spec_target_max_model_len + + default_max_len = 2048 + logger.warning( + "The model's config.json does not contain any of the following " + "keys to determine the original maximum length of the model: " + "%s. Assuming the model's maximum length is %d.", possible_keys, + default_max_len) + derived_max_model_len = default_max_len + + rope_scaling = getattr(hf_config, "rope_scaling", None) + # NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE + # scaling, so we skip applying the scaling factor again. + if rope_scaling is not None and "gemma3" not in hf_config.model_type: + # No need to consider "type" key because of patch_rope_scaling when + # loading HF config + rope_type = rope_scaling["rope_type"] + + if rope_type not in ("su", "longrope", "llama3"): + if disable_sliding_window: + # TODO(robertgshaw): Find a model that supports rope_scaling + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "with rope_scaling. Please raise an issue so we can " + "investigate.") + + # NOTE: rope_type == "default" does not define factor + # https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py + scaling_factor = rope_scaling.get("factor", 1.0) + + if rope_type == "yarn": + derived_max_model_len = rope_scaling[ + "original_max_position_embeddings"] + derived_max_model_len *= scaling_factor + + if encoder_config and "max_seq_length" in encoder_config: + derived_max_model_len = encoder_config["max_seq_length"] + + # If the user specified a max length, make sure it is smaller than the + # derived length from the HF model config. + if max_model_len is None: + max_model_len = int(derived_max_model_len) + if current_platform.is_tpu(): + logger.warning( + "--max-model-len is not specified, " + "it's currently using model's default length %s, " + "which might be too large." + "Please input with --max-model-len based on your " + "request input length and output length, to avoid " + "unnecessary degradation.", max_model_len) + elif max_model_len > derived_max_model_len: + # Some models might have a separate key for specifying model_max_length + # that will be bigger than derived_max_model_len. We compare user input + # with model_max_length and allow this override when it's smaller. + model_max_length = getattr(hf_config, "model_max_length", None) + if model_max_length is not None and max_model_len <= model_max_length: + if disable_sliding_window: + # TODO(robertgshaw): Find a model that has model_max_length + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "model_max_length in the config. Please raise an issue " + "so we can investigate.") + else: + msg = ( + f"User-specified max_model_len ({max_model_len}) is greater " + f"than the derived max_model_len ({max_len_key}=" + f"{derived_max_model_len} or model_max_length=" + f"{model_max_length} in model's config.json). This may lead " + "to incorrect model outputs or CUDA errors.") + if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN: + logger.warning( + "%s Make sure the value is correct and within the " + "model context size.", msg) + else: + raise ValueError( + f"{msg} To allow overriding this maximum, set " + "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1") + return int(max_model_len) + + +def get_min_sliding_window( + sliding_window: Union[int, list[Optional[int]]]) -> int: + if isinstance(sliding_window, list): + return min(s for s in sliding_window if s is not None) + + return sliding_window + + +def get_served_model_name(model: str, + served_model_name: Optional[Union[str, list[str]]]): + """ + If the input is a non-empty list, the first model_name in + `served_model_name` is taken. + If the input is a non-empty string, it is used directly. + For cases where the input is either an empty string or an + empty list, the fallback is to use `self.model`. + """ + if not served_model_name: + return model + if isinstance(served_model_name, list): + return served_model_name[0] + return served_model_name + + +GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer", + "xgrammar", "guidance"] +GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"] +GuidedDecodingBackend = Literal[GuidedDecodingBackendV0, + GuidedDecodingBackendV1] + + +@config +@dataclass +class DecodingConfig: + """Dataclass which contains the decoding strategy of the engine.""" + + @property + @deprecated( + "`guided_decoding_backend` is deprecated and has been renamed to " + "`backend`. This will be removed in v0.10.0. Please use the " + "`backend` argument instead.") + def guided_decoding_backend(self) -> GuidedDecodingBackend: + return self.backend + + @guided_decoding_backend.setter + def guided_decoding_backend(self, value: GuidedDecodingBackend): + self.backend = value + + backend: GuidedDecodingBackend = "auto" if envs.VLLM_USE_V1 else "xgrammar" + """Which engine will be used for guided decoding (JSON schema / regex etc) + by default. With "auto", we will make opinionated choices based on request + contents and what the backend libraries currently support, so the behavior + is subject to change in each release.""" + + disable_fallback: bool = False + """If `True`, vLLM will not fallback to a different backend on error.""" + + disable_any_whitespace: bool = False + """If `True`, the model will not generate any whitespace during guided + decoding. This is only supported for xgrammar and guidance backends.""" + + disable_additional_properties: bool = False + """If `True`, the `guidance` backend will not use `additionalProperties` + in the JSON schema. This is only supported for the `guidance` backend and + is used to better align its behaviour with `outlines` and `xgrammar`.""" + + reasoning_backend: str = "" + """Select the reasoning parser depending on the model that you're using. + This is used to parse the reasoning content into OpenAI API format.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + if ":" in self.backend: + self._extract_backend_options() + + if envs.VLLM_USE_V1: + valid_guided_backends = get_args(GuidedDecodingBackendV1) + else: + valid_guided_backends = get_args(GuidedDecodingBackendV0) + if self.backend not in valid_guided_backends: + raise ValueError(f"Invalid backend '{self.backend}'," + f" must be one of {valid_guided_backends}") + if (self.disable_any_whitespace + and self.backend not in ("xgrammar", "guidance")): + raise ValueError("disable_any_whitespace is only supported for " + "xgrammar and guidance backends.") + if (self.disable_additional_properties and self.backend != "guidance"): + raise ValueError("disable_additional_properties is only supported " + "for the guidance backend.") + + @deprecated( + "Passing guided decoding backend options inside backend in the format " + "'backend:...' is deprecated. This will be removed in v0.10.0. Please " + "use the dedicated arguments '--disable-fallback', " + "'--disable-any-whitespace' and '--disable-additional-properties' " + "instead.") + def _extract_backend_options(self): + """Extract backend options from the backend string.""" + backend, options = self.backend.split(":") + self.backend = cast(GuidedDecodingBackend, backend) + options_set = set(options.strip().split(",")) + if "no-fallback" in options_set: + self.disable_fallback = True + if "disable-any-whitespace" in options_set: + self.disable_any_whitespace = True + if "no-additional-properties" in options_set: + self.disable_additional_properties = True + + +DetailedTraceModules = Literal["model", "worker", "all"] + + +@config +@dataclass +class ObservabilityConfig: + """Configuration for observability - metrics and tracing.""" + + show_hidden_metrics_for_version: Optional[str] = None + """Enable deprecated Prometheus metrics that have been hidden since the + specified version. For example, if a previously deprecated metric has been + hidden since the v0.7.0 release, you use + `--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while + you migrate to new metrics. The metric is likely to be removed completely + in an upcoming release.""" + + @cached_property + def show_hidden_metrics(self) -> bool: + """Check if the hidden metrics should be shown.""" + if self.show_hidden_metrics_for_version is None: + return False + return version._prev_minor_version_was( + self.show_hidden_metrics_for_version) + + otlp_traces_endpoint: Optional[str] = None + """Target URL to which OpenTelemetry traces will be sent.""" + + collect_detailed_traces: Optional[list[DetailedTraceModules]] = None + """It makes sense to set this only if `--otlp-traces-endpoint` is set. If + set, it will collect detailed traces for the specified modules. This + involves use of possibly costly and or blocking operations and hence might + have a performance impact. + + Note that collecting detailed timing information for each request can be + expensive.""" + + @cached_property + def collect_model_forward_time(self) -> bool: + """Whether to collect model forward time for the request.""" + return (self.collect_detailed_traces is not None + and ("model" in self.collect_detailed_traces + or "all" in self.collect_detailed_traces)) + + @cached_property + def collect_model_execute_time(self) -> bool: + """Whether to collect model execute time for the request.""" + return (self.collect_detailed_traces is not None + and ("worker" in self.collect_detailed_traces + or "all" in self.collect_detailed_traces)) + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self): + if (self.collect_detailed_traces is not None + and len(self.collect_detailed_traces) == 1 + and "," in self.collect_detailed_traces[0]): + self._parse_collect_detailed_traces() + + from vllm.tracing import is_otel_available, otel_import_error_traceback + if not is_otel_available() and self.otlp_traces_endpoint is not None: + raise ValueError( + "OpenTelemetry is not available. Unable to configure " + "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are " + f"installed. Original error:\n{otel_import_error_traceback}") + + def _parse_collect_detailed_traces(self): + assert isinstance(self.collect_detailed_traces, list) + self.collect_detailed_traces = cast( + list[DetailedTraceModules], + self.collect_detailed_traces[0].split(",")) + + +KVProducer = Literal["kv_producer", "kv_both"] +KVConsumer = Literal["kv_consumer", "kv_both"] +KVRole = Literal[KVProducer, KVConsumer] + + +@config +@dataclass +class KVTransferConfig: + """Configuration for distributed KV cache transfer.""" + + kv_connector: Optional[str] = None + """The KV connector for vLLM to transmit KV caches between vLLM instances. + """ + + engine_id: Optional[str] = None + """The engine id for KV transfers.""" + + kv_buffer_device: Optional[str] = "cuda" + """The device used by kv connector to buffer the KV cache. + Currently only support 'cuda'.""" + + kv_buffer_size: float = 1e9 + """The buffer size for TorchDistributedConnector. Measured in number of + bytes. Recommended value: 1e9 (about 1GB).""" + + kv_role: Optional[KVRole] = None + """Whether this vLLM instance produces, consumes KV cache, or both. Choices + are 'kv_producer', 'kv_consumer', and 'kv_both'.""" + + kv_rank: Optional[int] = None + """The rank of this vLLM instance in the KV cache transfer. Typical value: + 0 for prefill instance, 1 for decode instance. + Currently only 1P1D is supported.""" + + kv_parallel_size: int = 1 + """The number of parallel instances for KV cache transfer. For + PyNcclConnector, this should be 2.""" + + kv_ip: str = "127.0.0.1" + """The KV connector ip, used to build distributed connection.""" + + kv_port: int = 14579 + """The KV connector port, used to build distributed connection.""" + + kv_connector_extra_config: dict[str, Any] = field(default_factory=dict) + """any extra config that the connector may need.""" + + kv_connector_module_path: Optional[str] = None + """The Python module path to dynamically load the KV connector from. + Only supported in V1.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # no factors to consider. + # this config will not affect the computation graph. + factors: list[Any] = [] + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + return hash_str + + def __post_init__(self) -> None: + if self.engine_id is None: + self.engine_id = str(uuid.uuid4()) + + if self.kv_role is not None and self.kv_role not in get_args(KVRole): + raise ValueError(f"Unsupported kv_role: {self.kv_role}. " + f"Supported roles are {get_args(KVRole)}") + + if self.kv_connector is not None and self.kv_role is None: + raise ValueError("Please specify kv_disagg_role when kv_connector " + f"is set, supported roles are {get_args(KVRole)}") + + @property + def is_kv_transfer_instance(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in get_args(KVRole) + + @property + def is_kv_producer(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in get_args(KVProducer) + + @property + def is_kv_consumer(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in get_args(KVConsumer) + + def get_from_extra_config(self, key, default) -> Any: + return self.kv_connector_extra_config.get(key, default) + + +@config +@dataclass +class KVEventsConfig: + """Configuration for KV event publishing.""" + + enable_kv_cache_events: bool = False + """If True, enable KV cache events for tracking block storage and removal. + Events can be published externally by zmq using the event publisher config. + """ + + publisher: str = "null" + """The publisher to use for publishing kv events. Can be "null", "zmq". + """ + + endpoint: str = "tcp://*:5557" + """The zmq endpoint to use for publishing kv events. + """ + + replay_endpoint: Optional[str] = None + """The zmq endpoint to use for replaying kv events. + """ + + buffer_steps: int = 10_000 + """The number of steps to cache for replay endpoint. Will only save + events from the last N steps for the replay endpoint. + """ + + hwm: int = 100_000 + """The zmq high water mark for the event publisher. After queueing N events, + events will start dropping if the consumer is not keeping up. + """ + + max_queue_size: int = 100_000 + """The maximum number of events to queue while waiting for publishing. + """ + + topic: str = "" + """The topic to use for the event publisher. Consumers can subscribe to + this topic to receive events. + """ + + +class CompilationLevel: + # constants for the levels of the compilation process + NO_COMPILATION = 0 + DYNAMO_AS_IS = 1 + DYNAMO_ONCE = 2 + PIECEWISE = 3 + + +@config +@dataclass +class PassConfig: + """Configuration for custom Inductor passes. + + This is separate from general `CompilationConfig` so that inductor passes + don't all have access to full configuration - that would create a cycle as + the `PassManager` is set as a property of config.""" + + dump_graph_stages: list[str] = field(default_factory=list) + """List of stages for which we want to dump the graph. Each pass defines + its own stages (before, after, maybe in-between).""" + dump_graph_dir: Path = Path(".") + """Directory to dump the graphs.""" + enable_fusion: bool = field(default_factory=lambda: not envs.VLLM_USE_V1) + """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass.""" + enable_attn_fusion: bool = False + """Whether to enable the custom attention+quant fusion pass.""" + enable_noop: bool = field(default_factory=lambda: not envs.VLLM_USE_V1) + """Whether to enable the custom no-op elimination pass.""" + enable_sequence_parallelism: bool = False + """Whether to enable sequence parallelism.""" + enable_async_tp: bool = False + """Whether to enable async TP.""" + + # TODO(luka) better pass enabling system. + + def uuid(self): + """ + Produces a hash unique to the pass configuration. + Any new fields that affect compilation should be added to the hash. + Do not include dump_graph_* in the hash - they don't affect + compilation. + """ + exclude = {"dump_graph_stages", "dump_graph_dir"} + dict_ = {k: v for k, v in asdict(self).items() if k not in exclude} + return InductorPass.hash_dict(dict_) + + def __post_init__(self) -> None: + if not self.enable_noop: + if self.enable_fusion: + logger.warning_once( + "Fusion enabled but reshape elimination disabled. " + "RMSNorm/SiluMul + quant (fp8) fusion might not work") + if self.enable_attn_fusion: + logger.warning_once( + "Fusion enabled but reshape elimination disabled. " + "Attention + quant (fp8) fusion might not work") + + +@config +@dataclass +class CompilationConfig: + """Configuration for compilation. It has three parts: + + - Top-level Compilation control: + - [`level`][vllm.config.CompilationConfig.level] + - [`debug_dump_path`][vllm.config.CompilationConfig.debug_dump_path] + - [`cache_dir`][vllm.config.CompilationConfig.cache_dir] + - [`backend`][vllm.config.CompilationConfig.backend] + - [`custom_ops`][vllm.config.CompilationConfig.custom_ops] + - [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops] + - CudaGraph capture: + - [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph] + - [`cudagraph_capture_sizes`] + [vllm.config.CompilationConfig.cudagraph_capture_sizes] + - [`cudagraph_num_of_warmups`] + [vllm.config.CompilationConfig.cudagraph_num_of_warmups] + - [`cudagraph_copy_inputs`] + [vllm.config.CompilationConfig.cudagraph_copy_inputs] + - [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph] + - Inductor compilation: + - [`use_inductor`][vllm.config.CompilationConfig.use_inductor] + - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes] + - [`inductor_compile_config`] + [vllm.config.CompilationConfig.inductor_compile_config] + - [`inductor_passes`][vllm.config.CompilationConfig.inductor_passes] + - custom inductor passes + + Why we have different sizes for cudagraph and inductor: + - cudagraph: a cudagraph captured for a specific size can only be used + for the same size. We need to capture all the sizes we want to use. + - inductor: a graph compiled by inductor for a general shape can be used + for different sizes. Inductor can also compile for specific sizes, + where it can have more information to optimize the graph with fully + static shapes. However, we find the general shape compilation is + sufficient for most cases. It might be beneficial to compile for + certain small batchsizes, where inductor is good at optimizing. + """ + # Top-level Compilation control + level: int = 0 + """The level of compilation: + + - 0: no compilation. + - 1: dynamo as is. + - 2: dynamo once. + - 3: piecewise compilation.""" + debug_dump_path: str = "" + """The path to dump the debug information.""" + cache_dir: str = "" + """The directory to store the compiled graph, to accelerate Inductor + compilation. By default, it will use model-related information to generate + a cache directory.""" + backend: str = "" + """The backend for compilation. It needs to be a string: + + - "" (empty string): use the default backend. + - "eager"/"openxla"/...: use the specified backend registered in PyTorch. + - "full.module.name": a qualified name which can be used to import the + + backend function. + We use string to avoid serialization issues when using compilation in a + distributed setting. When the compilation level is 1 or 2, the backend is + used for the compilation directly (it sees the whole graph). When the + compilation level is 3, the backend is used for the piecewise compilation + (it sees a part of the graph).""" + custom_ops: list[str] = field(default_factory=list) + """Fine-grained control over which custom ops to enable/disable. Use 'all' + to enable all, 'none' to disable all. Also specify a list of custom op + names to enable (prefixed with a '+'), or disable (prefixed with a '-'). + Examples: + + - 'all,-op1' to enable all except op1 + - 'none,+op1,+op2' to enable only op1 and op2 + + By default, all custom ops are enabled when running without Inductor and + disabled when running with Inductor: level>=PIECEWISE and use_inductor=True. + Inductor generates (fused) Triton kernels for disabled custom ops.""" + splitting_ops: list[str] = field(default_factory=list) + """A list of ops to split the full graph into subgraphs, used in piecewise + compilation.""" + + # Inductor capture + use_inductor: bool = True + """Whether to use inductor compilation: + + - False: inductor compilation is not used. graph runs in eager + (custom_ops enabled by default). + - True: inductor compilation is used (custom_ops disabled by default). + One graph for symbolic shape and one graph per size in compile_sizes + are compiled using configurations in inductor_compile_config. + + This setting is ignored if level1.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + factors.append(self.level) + factors.append(self.backend) + factors.append(self.custom_ops) + factors.append(self.splitting_ops) + factors.append(self.use_inductor) + factors.append(self.inductor_compile_config) + factors.append(self.inductor_passes) + factors.append(self.pass_config.uuid()) + return hashlib.sha256(str(factors).encode()).hexdigest() + + def __repr__(self) -> str: + exclude = { + "static_forward_context": True, + "enabled_custom_ops": True, + "disabled_custom_ops": True, + "compilation_time": True, + "bs_to_padded_graph_size": True, + "pass_config": True, + "traced_files": True, + "inductor_compile_config": { + "post_grad_custom_post_pass": True, + }, + } + # The cast to string is necessary because Pydantic is mocked in docs + # builds and sphinx-argparse doesn't know the return type of decode() + return str( + TypeAdapter(CompilationConfig).dump_json( + self, + exclude=exclude, # type: ignore[arg-type] + exclude_unset=True).decode()) + + __str__ = __repr__ + + @classmethod + def from_cli(cls, cli_value: str) -> "CompilationConfig": + """Parse the CLI value for the compilation config. + -O1, -O2, -O3, etc. is handled in FlexibleArgumentParser. + """ + return TypeAdapter(CompilationConfig).validate_json(cli_value) + + def __post_init__(self) -> None: + count_none = self.custom_ops.count("none") + count_all = self.custom_ops.count("all") + assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" + + # TODO(zou3519/luka): There are 2 issues with auto-functionalization V2: + # 1. A bug in PyTorch, fixed in 2.7: + # https://github.com/pytorch/pytorch/issues/147924 + # 2. Custom passes (fusion) rely on auto-functionalization V1 and don't + # work with V2. Addressing this will take extra engineering effort + # and it is not yet a priority. RFC here: + # https://github.com/vllm-project/vllm/issues/14703 + + if is_torch_equal_or_newer("2.6"): + KEY = 'enable_auto_functionalized_v2' + if KEY not in self.inductor_compile_config: + self.inductor_compile_config[KEY] = False + + for k, v in self.inductor_passes.items(): + if not isinstance(v, str): + assert callable(v), ( + f"pass {k} should be callable or a qualified name") + self.inductor_compile_config[k] = v if isinstance( + v, InductorPass) else CallableInductorPass(v) + continue + + # resolve function from qualified name + names = v.split(".") + module = ".".join(names[:-1]) + func_name = names[-1] + func = __import__(module).__dict__[func_name] + self.inductor_compile_config[k] = func if isinstance( + func, InductorPass) else CallableInductorPass(func) + + if isinstance(self.pass_config, dict): + self.pass_config = PassConfig(**self.pass_config) + + def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: + if self.level == CompilationLevel.NO_COMPILATION: + raise ValueError("No compilation level is set.") + + from torch._dynamo.backends.registry import list_backends + torch_backends = list_backends(exclude_tags=tuple()) + if self.level in [ + CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE + ]: + if self.backend == "": + return "eager" + if self.backend in torch_backends: + return self.backend + return resolve_obj_by_qualname(self.backend) + + # TODO: pass user-specified backend to piecewise compilation + # merge with the config use_inductor + assert self.level == CompilationLevel.PIECEWISE + + from vllm.compilation.backends import VllmBackend + return VllmBackend(vllm_config) + + def init_with_cudagraph_sizes(self, + cudagraph_capture_sizes: list[int]) -> None: + """To complete the initialization of config, + we need to know the cudagraph sizes.""" + + if self.cudagraph_capture_sizes is None: + self.cudagraph_capture_sizes = cudagraph_capture_sizes + else: + # de-duplicate the sizes provided by the config + dedup_sizes = list(set(self.cudagraph_capture_sizes)) + if len(dedup_sizes) < len(self.cudagraph_capture_sizes): + logger.info(("cudagraph sizes specified by model runner" + " %s is overridden by config %s"), + cudagraph_capture_sizes, dedup_sizes) + self.cudagraph_capture_sizes = dedup_sizes + + computed_compile_sizes = [] + if self.compile_sizes is not None: + # de-duplicate the sizes provided by the config + self.compile_sizes = list(set(self.compile_sizes)) + for x in self.compile_sizes: + if isinstance(x, str): + assert x == "cudagraph_capture_sizes", \ + "Unrecognized size type in compile_sizes, " \ + f"expect 'cudagraph_capture_sizes', got {x}" + computed_compile_sizes.extend(self.cudagraph_capture_sizes) + else: + assert isinstance(x, int) + computed_compile_sizes.append(x) + self.compile_sizes = computed_compile_sizes # type: ignore + + # sort to make sure cudagraph capture sizes are in descending order + self.cudagraph_capture_sizes.sort(reverse=True) + self.max_capture_size = self.cudagraph_capture_sizes[ + 0] if self.cudagraph_capture_sizes else 0 + + # pre-compute the mapping from batch size to padded graph size + self.bs_to_padded_graph_size = [ + 0 for i in range(self.max_capture_size + 1) + ] + for end, start in zip(self.cudagraph_capture_sizes, + self.cudagraph_capture_sizes[1:] + [0]): + for bs in range(start, end): + if bs == start: + self.bs_to_padded_graph_size[bs] = start + else: + self.bs_to_padded_graph_size[bs] = end + self.bs_to_padded_graph_size[ + self.max_capture_size] = self.max_capture_size + + def set_splitting_ops_for_v1(self): + # NOTE: this function needs to be called + if self.splitting_ops and self.full_cuda_graph: + raise ValueError("full_cuda_graph cannot be used together with " + "splitting_ops, as Full CUDA graph will override " + f"the splitting_ops: {self.splitting_ops}") + + if not self.splitting_ops: + self.splitting_ops = [] if self.full_cuda_graph else [ + "vllm.unified_attention", + "vllm.unified_attention_with_output", + ] + + +@config +@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +class VllmConfig: + """Dataclass which contains all vllm-related configuration. This + simplifies passing around the distinct configurations in the codebase. + """ + + # TODO: use default_factory once default constructing ModelConfig doesn't + # try to download a model + model_config: ModelConfig = None # type: ignore + """Model configuration.""" + cache_config: CacheConfig = field(default_factory=CacheConfig) + """Cache configuration.""" + parallel_config: ParallelConfig = field(default_factory=ParallelConfig) + """Parallel configuration.""" + scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig) + """Scheduler configuration.""" + device_config: DeviceConfig = field(default_factory=DeviceConfig) + """Device configuration.""" + load_config: LoadConfig = field(default_factory=LoadConfig) + """Load configuration.""" + lora_config: Optional[LoRAConfig] = None + """LoRA configuration.""" + speculative_config: Optional[SpeculativeConfig] = None + """Speculative decoding configuration.""" + decoding_config: DecodingConfig = field(default_factory=DecodingConfig) + """Decoding configuration.""" + observability_config: Optional[ObservabilityConfig] = None + """Observability configuration.""" + prompt_adapter_config: Optional[PromptAdapterConfig] = None + """Prompt adapter configuration.""" + quant_config: Optional[QuantizationConfig] = None + """Quantization configuration.""" + compilation_config: CompilationConfig = field( + default_factory=CompilationConfig) + """`torch.compile` and cudagraph capture configuration for the model. + + As a shorthand, `-O` can be used to directly specify the compilation + level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`). + Currently, -O and -O= are supported as well but this will likely be + removed in favor of clearer -O syntax in the future. + + NOTE: level 0 is the default level without any optimization. level 1 and 2 + are for internal testing only. level 3 is the recommended level for + production, also default in V1. + + You can specify the full compilation config like so: + `{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}` + """ + kv_transfer_config: Optional[KVTransferConfig] = None + """The configurations for distributed KV cache transfer.""" + kv_events_config: Optional[KVEventsConfig] = None + """The configurations for event publishing.""" + # some opaque config, only used to provide additional information + # for the hash computation, mainly used for testing, debugging or out of + # tree config registration. + additional_config: Union[dict, SupportsHash] = field(default_factory=dict) + """Additional config for specified platform. Different platforms may + support different configs. Make sure the configs are valid for the platform + you are using. Contents must be hashable.""" + instance_id: str = "" + """The ID of the vLLM instance.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + factors: list[Any] = [] + + # summarize vllm config + vllm_factors: list[Any] = [] + from vllm import __version__ + vllm_factors.append(__version__) + vllm_factors.append(envs.VLLM_USE_V1) + if self.model_config: + vllm_factors.append(self.model_config.compute_hash()) + else: + vllm_factors.append("None") + if self.cache_config: + vllm_factors.append(self.cache_config.compute_hash()) + else: + vllm_factors.append("None") + if self.parallel_config: + vllm_factors.append(self.parallel_config.compute_hash()) + else: + vllm_factors.append("None") + if self.scheduler_config: + vllm_factors.append(self.scheduler_config.compute_hash()) + else: + vllm_factors.append("None") + if self.device_config: + vllm_factors.append(self.device_config.compute_hash()) + else: + vllm_factors.append("None") + if self.load_config: + vllm_factors.append(self.load_config.compute_hash()) + else: + vllm_factors.append("None") + if self.lora_config: + vllm_factors.append(self.lora_config.compute_hash()) + # LoRA creates static buffers based on max_num_batched_tokens. + # The tensor sizes and strides get captured in the torch.compile + # graph explicitly. + vllm_factors.append( + str(self.scheduler_config.max_num_batched_tokens)) + else: + vllm_factors.append("None") + if self.speculative_config: + vllm_factors.append(self.speculative_config.compute_hash()) + else: + vllm_factors.append("None") + if self.decoding_config: + vllm_factors.append(self.decoding_config.compute_hash()) + else: + vllm_factors.append("None") + if self.observability_config: + vllm_factors.append(self.observability_config.compute_hash()) + else: + vllm_factors.append("None") + if self.prompt_adapter_config: + vllm_factors.append(self.prompt_adapter_config.compute_hash()) + else: + vllm_factors.append("None") + if self.quant_config: + pass # should be captured by model_config.quantization + if self.compilation_config: + vllm_factors.append(self.compilation_config.compute_hash()) + else: + vllm_factors.append("None") + if self.kv_transfer_config: + vllm_factors.append(self.kv_transfer_config.compute_hash()) + else: + vllm_factors.append("None") + if self.additional_config: + if isinstance(additional_config := self.additional_config, dict): + additional_config_hash = hashlib.md5( + json.dumps(additional_config, sort_keys=True).encode(), + usedforsecurity=False, + ).hexdigest() + else: + additional_config_hash = additional_config.compute_hash() + vllm_factors.append(additional_config_hash) + else: + vllm_factors.append("None") + factors.append(vllm_factors) + + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest()[:10] + return hash_str + + def pad_for_cudagraph(self, batch_size: int) -> int: + # if batch_size > self.compilation_config.max_capture_size, + # it should raise an IndexError. + # the caller should make sure the batch_size is within the range, + # i.e., batch_size <= self.compilation_config.max_capture_size + return self.compilation_config.bs_to_padded_graph_size[batch_size] + + @staticmethod + def _get_quantization_config( + model_config: ModelConfig, + load_config: LoadConfig) -> Optional[QuantizationConfig]: + """Get the quantization config.""" + from vllm.platforms import current_platform + if model_config.quantization is not None: + from vllm.model_executor.model_loader.weight_utils import ( + get_quant_config) + quant_config = get_quant_config(model_config, load_config) + capability_tuple = current_platform.get_device_capability() + + if capability_tuple is not None: + capability = capability_tuple.to_int() + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} " + "is not supported for the current GPU. Minimum " + f"capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}.") + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. Supported dtypes: " + f"{supported_dtypes}") + return quant_config + return None + + @staticmethod + def get_quantization_config( + model_config: ModelConfig, + load_config: LoadConfig) -> Optional[QuantizationConfig]: + import copy + + # For some reason, the _ version of this modifies the model_config + # object, so using deepcopy to avoid this problem. + return VllmConfig._get_quantization_config(copy.deepcopy(model_config), + load_config) + + def with_hf_config( + self, + hf_config: PretrainedConfig, + architectures: Optional[list[str]] = None, + ) -> "VllmConfig": + if architectures is not None: + hf_config = copy.deepcopy(hf_config) + hf_config.architectures = architectures + + model_config = copy.deepcopy(self.model_config) + model_config.hf_config = hf_config + + return replace(self, model_config=model_config) + + def __post_init__(self): + """Verify configs are valid & consistent with each other. + """ + + self.try_verify_and_update_config() + + if self.model_config is not None: + self.model_config.verify_async_output_proc(self.parallel_config, + self.speculative_config, + self.device_config) + self.model_config.verify_with_parallel_config(self.parallel_config) + self.model_config.verify_dual_chunk_attention_config( + self.load_config) + + self.cache_config.verify_with_parallel_config(self.parallel_config) + + if self.lora_config is not None: + self.lora_config.verify_with_cache_config(self.cache_config) + self.lora_config.verify_with_model_config(self.model_config) + self.lora_config.verify_lora_support() + if self.prompt_adapter_config is not None: + self.prompt_adapter_config.verify_with_model_config( + self.model_config) + + if self.quant_config is None and self.model_config is not None: + self.quant_config = VllmConfig._get_quantization_config( + self.model_config, self.load_config) + + from vllm.platforms import current_platform + if self.model_config is not None and \ + self.scheduler_config.chunked_prefill_enabled and \ + self.model_config.dtype == torch.float32 and \ + current_platform.get_device_capability() == (7, 5): + logger.warning_once( + "Turing devices tensor cores do not support float32 matmul. " + "To workaround this limitation, vLLM will set 'ieee' input " + "precision for chunked prefill triton kernels.") + + # async tp is built on top of sequence parallelism + # and requires it to be enabled. + if self.compilation_config.pass_config.enable_async_tp: + self.compilation_config.pass_config.enable_sequence_parallelism = \ + True + if self.compilation_config.pass_config.enable_sequence_parallelism: + self.compilation_config.custom_ops.append("+rms_norm") + if envs.VLLM_USE_V1 and self.model_config is not None and \ + not self.model_config.enforce_eager: + # By default, V1 uses piecewise CUDA graphs. If full_cuda_graph + # is set to True, full CUDA graphs will be used. + self.compilation_config.cudagraph_num_of_warmups = 1 + self.compilation_config.level = CompilationLevel.PIECEWISE + self.compilation_config.set_splitting_ops_for_v1() + + self._set_cudagraph_sizes() + + if self.cache_config.cpu_offload_gb > 0 and \ + self.compilation_config.level != CompilationLevel.NO_COMPILATION \ + and not envs.VLLM_USE_V1: + logger.warning( + "CPU offload is not supported with `torch.compile` in v0 yet." + " Disabling `torch.compile`.") + self.compilation_config.level = CompilationLevel.NO_COMPILATION + + if ((not envs.VLLM_USE_V1) and self.lora_config is not None + and self.compilation_config.level + != CompilationLevel.NO_COMPILATION): + logger.warning( + "LoRA for V0 is not supported with `torch.compile` yet. " + "Disabling `torch.compile`.") + self.compilation_config.level = CompilationLevel.NO_COMPILATION + + if self.compilation_config.full_cuda_graph and \ + not self.model_config.disable_cascade_attn: + logger.info("full_cuda_graph is not supported with " + "cascade attention. Disabling cascade attention.") + self.model_config.disable_cascade_attn = True + + disable_chunked_prefill_reasons: list[str] = [] + + if self.model_config and self.model_config.pooler_config: + pooling_type = self.model_config.pooler_config.pooling_type + if pooling_type is None or pooling_type.lower() != "last": + disable_chunked_prefill_reasons.append( + "Only \"last\" pooling supports chunked " + "prefill and prefix caching; disabling both.") + + if disable_chunked_prefill_reasons: + for reason in disable_chunked_prefill_reasons: + logger.info(reason) + self.scheduler_config.chunked_prefill_enabled = False + self.scheduler_config.long_prefill_token_threshold = 0 + self.scheduler_config.max_num_batched_tokens = max( + self.scheduler_config.max_model_len, + DEFAULT_MAX_NUM_BATCHED_TOKENS) + + if self.cache_config is not None: + self.cache_config.enable_prefix_caching = False + + if (self.kv_events_config is not None + and self.kv_events_config.enable_kv_cache_events + and not self.cache_config.enable_prefix_caching): + logger.warning( + "KV cache events are on, but prefix caching is not enabled." + "Use --enable-prefix-caching to enable.") + if (self.kv_events_config is not None + and self.kv_events_config.publisher != "null" + and not self.kv_events_config.enable_kv_cache_events): + logger.warning("KV cache events are disabled," + "but the scheduler is configured to publish them." + "Modify KVEventsConfig.enable_kv_cache_events" + "to True to enable.") + current_platform.check_and_update_config(self) + + if not self.instance_id: + self.instance_id = random_uuid()[:5] + + if (envs.VLLM_USE_V1 + and not self.scheduler_config.disable_hybrid_kv_cache_manager): + # logger should only print warning message for hybrid models. As we + # can't know whether the model is hybrid or not now, so we don't log + # warning message here and will log it later. + if not (current_platform.is_cuda() or current_platform.is_rocm()): + # Hybrid KV cache manager is not supported on non-GPU platforms. + self.scheduler_config.disable_hybrid_kv_cache_manager = True + if self.kv_transfer_config is not None: + # Hybrid KV cache manager is not compatible with KV transfer. + self.scheduler_config.disable_hybrid_kv_cache_manager = True + if self.kv_events_config is not None: + # Hybrid KV cache manager is not compatible with KV events. + self.scheduler_config.disable_hybrid_kv_cache_manager = True + + def update_sizes_for_sequence_parallelism(self, + possible_sizes: list) -> list: + # remove the sizes that not multiple of tp_size when + # enable sequence parallelism + removed_sizes = [ + size for size in possible_sizes + if size % self.parallel_config.tensor_parallel_size != 0 + ] + if removed_sizes: + logger.warning( + "Batch sizes %s are removed because they are not " + "multiple of tp_size %d when " + "sequence parallelism is enabled", removed_sizes, + self.parallel_config.tensor_parallel_size) + + return [ + size for size in possible_sizes + if size % self.parallel_config.tensor_parallel_size == 0 + ] + + def _set_cudagraph_sizes(self): + """ + cudagraph batchsize padding logic: + + `[1, 2, 4] + [8 * i for i in range(1, 1025)]` is a list of all possible + batch sizes that cudagraph will capture. + + Depending on the engine's configuration of `max_num_seqs`, the + candidate batch sizes to capture cudagraph will shrink to the subset + which just cover the range of `[1, max_num_seqs]`. In the common case, + `max_num_seqs` is 256, and the cudagraph batch sizes will be + `[1, 2, 4, 8, 16, 24, 32, 40, ..., 256]`. + + However, if users specify the cudagraph capture sizes through + compilation config, we will use the specified sizes instead. + + In the end, `vllm_config.compilation_config.cudagraph_capture_sizes` + will be the final sizes to capture cudagraph (in descending order). + + During runtime, if batchsize is larger than + `vllm_config.compilation_config.cudagraph_capture_sizes`, + no cudagraph will be used. + If the batch size is no larger than + `vllm_config.compilation_config.cudagraph_capture_sizes`, + we can quickly find the padded graph size for a given batch size by + looking up `vllm_config.compilation_config.bs_to_padded_graph_size`. + """ + + # calculate the default `batch_size_capture_list` + if not envs.VLLM_USE_V1: + batch_size_capture_list = [] + max_batchsize_to_capture = 0 + if self.scheduler_config is not None and \ + self.model_config is not None and \ + not self.model_config.enforce_eager: + + possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)] + if self.parallel_config.tensor_parallel_size > 1 and \ + self.compilation_config.pass_config.enable_sequence_parallelism: + possible_sizes = self.update_sizes_for_sequence_parallelism( + possible_sizes) + + # find the minimum size that is larger than max_num_seqs, + # which then becomes the max_batchsize_to_capture + larger_sizes = [ + x for x in possible_sizes + if x >= self.scheduler_config.max_num_seqs + ] + if larger_sizes: + max_batchsize_to_capture = larger_sizes[0] + else: + max_batchsize_to_capture = possible_sizes[-1] + + # filter out the sizes that are + # larger than max_batchsize_to_capture + batch_size_capture_list = [ + size for size in possible_sizes + if size <= max_batchsize_to_capture + ] + else: + batch_size_capture_list = [] + if self.model_config is not None and \ + not self.model_config.enforce_eager: + if self.model_config.use_mla and self.compilation_config.full_cuda_graph and self.scheduler_config.max_num_seqs<=512: + cuda_graph_sizes = [self.scheduler_config.max_num_seqs] + else: + cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes + if len(cuda_graph_sizes) == 1: + batch_size_capture_list = [1, 2, 4] + [ + i for i in range(8, cuda_graph_sizes[0] + 1, 8) + ] + elif len(cuda_graph_sizes) > 1: + batch_size_capture_list = sorted(cuda_graph_sizes) + else: + raise TypeError(f"Invalid value for {cuda_graph_sizes=}.") + if self.parallel_config.tensor_parallel_size > 1 and \ + self.compilation_config.pass_config.enable_sequence_parallelism: + batch_size_capture_list = \ + self.update_sizes_for_sequence_parallelism(batch_size_capture_list) + max_num_tokens = self.scheduler_config.max_num_batched_tokens + batch_size_capture_list = [ + size for size in batch_size_capture_list + if size <= max_num_tokens + ] + + # add for spec decode + if self.speculative_config is not None and self.speculative_config.num_lookahead_slots > 0: + batch_size_capture_list = list(map(lambda x: x * (1 + self.speculative_config.num_lookahead_slots), + batch_size_capture_list)) + + self.compilation_config.init_with_cudagraph_sizes( + batch_size_capture_list) + + def recalculate_max_model_len(self, max_model_len: int): + # Can only be called in try_verify_and_update_config + model_config = self.model_config + max_model_len = model_config.get_and_verify_max_len(max_model_len) + self.model_config.max_model_len = max_model_len + self.scheduler_config.max_model_len = max_model_len + + def try_verify_and_update_config(self): + architecture = getattr(self.model_config, "architecture", None) + if architecture is None: + return + + from vllm.model_executor.models.config import MODELS_CONFIG_MAP + cls = MODELS_CONFIG_MAP.get(architecture, None) + if cls is not None: + cls.verify_and_update_config(self) + + if self.model_config.task == "classify": + # Maybe convert ForCausalLM into ForSequenceClassification model. + from vllm.model_executor.models.adapters import ( + SequenceClassificationConfig) + SequenceClassificationConfig.verify_and_update_config(self) + + def __str__(self): + return ( + f"model={self.model_config.model!r}," + f" speculative_config={self.speculative_config!r}," + f" tokenizer={self.model_config.tokenizer!r}, " + f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}," + f" tokenizer_mode={self.model_config.tokenizer_mode}, " + f"revision={self.model_config.revision}, " + f"override_neuron_config={self.model_config.override_neuron_config}," + f" tokenizer_revision={self.model_config.tokenizer_revision}, " + f"trust_remote_code={self.model_config.trust_remote_code}, " + f"dtype={self.model_config.dtype}, " + f"max_seq_len={self.model_config.max_model_len}," + f" download_dir={self.load_config.download_dir!r}, " + f"load_format={self.load_config.load_format}, " + f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}," + f" pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa + f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa + f"quantization={self.model_config.quantization}, " + f"enforce_eager={self.model_config.enforce_eager}, " + f"kv_cache_dtype={self.cache_config.cache_dtype}, " + f" device_config={self.device_config.device}, " + f"decoding_config={self.decoding_config!r}, " + f"observability_config={self.observability_config!r}, " + f"seed={self.model_config.seed}, " + f"served_model_name={self.model_config.served_model_name}, " + f"num_scheduler_steps={self.scheduler_config.num_scheduler_steps}, " + f"multi_step_stream_outputs={self.scheduler_config.multi_step_stream_outputs}, " # noqa + f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " + f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa + f"use_async_output_proc={self.model_config.use_async_output_proc}, " + f"pooler_config={self.model_config.pooler_config!r}, " + f"compilation_config={self.compilation_config!r}") + + +_current_vllm_config: Optional[VllmConfig] = None +_current_prefix: Optional[str] = None + + +@contextmanager +def set_current_vllm_config(vllm_config: VllmConfig, + check_compile=False, + prefix: Optional[str] = None): + """ + Temporarily set the current vLLM config. + Used during model initialization. + We save the current vLLM config in a global variable, + so that all modules can access it, e.g. custom ops + can access the vLLM config to determine how to dispatch. + """ + global _current_vllm_config, _current_prefix + old_vllm_config = _current_vllm_config + old_prefix = _current_prefix + from vllm.compilation.counter import compilation_counter + num_models_seen = compilation_counter.num_models_seen + try: + _current_vllm_config = vllm_config + _current_prefix = prefix + yield + except Exception: + raise + else: + logger.debug("enabled custom ops: %s", + vllm_config.compilation_config.enabled_custom_ops) + logger.debug("disabled custom ops: %s", + vllm_config.compilation_config.disabled_custom_ops) + if check_compile and \ + vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ + and compilation_counter.num_models_seen == num_models_seen: + # If the model supports compilation, + # compilation_counter.num_models_seen should be increased + # by at least 1. + # If it is not increased, it means the model does not support + # compilation (does not have @support_torch_compile decorator). + logger.warning( + "`torch.compile` is turned on, but the model %s" + " does not support it. Please open an issue on GitHub" + " if you want it to be supported.", + vllm_config.model_config.model) + finally: + _current_vllm_config = old_vllm_config + _current_prefix = old_prefix + + +def get_current_vllm_config() -> VllmConfig: + if _current_vllm_config is None: + # in ci, usually when we test custom ops/modules directly, + # we don't set the vllm config. In that case, we set a default + # config. + logger.warning("Current vLLM config is not set.") + from vllm.config import VllmConfig + return VllmConfig() + return _current_vllm_config + + +def get_current_model_prefix() -> str: + """ + Get the prefix of the model that's currently being initialized. + """ + assert _current_prefix is not None, \ + "Current model prefix is not set. " + return _current_prefix + + +def contains_object_print(text): + """ + Check if the text looks like a printed Python object, e.g. + contains any substring matching the pattern: "at 0xFFFFFFF>" + We match against 0x followed by 2-16 hex chars (there's + a max of 16 on a 64 bit system). + + Args: + text (str): The text to check + + Returns: + result (bool): `True` if a match is found, `False` otherwise. + """ + pattern = r'at 0x[a-fA-F0-9]{2,16}>' + match = re.search(pattern, text) + return match is not None + + +def assert_hashable(text): + if not contains_object_print(text): + return True + raise AssertionError( + f"vLLM tried to hash some configs that may have Python objects ids " + f"in them. This is a bug, please file an issue. " + f"Text being hashed: {text}") + + +T = TypeVar("T") + + +def get_layers_from_vllm_config(vllm_config: VllmConfig, + layer_type: type[T]) -> dict[str, T]: + return { + layer_name: layer + for layer_name, layer in + vllm_config.compilation_config.static_forward_context.items() + if isinstance(layer, layer_type) + } diff --git a/vllm/connections.py b/vllm/connections.py new file mode 100644 index 0000000..103505e --- /dev/null +++ b/vllm/connections.py @@ -0,0 +1,174 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Mapping, MutableMapping +from pathlib import Path +from typing import Optional +from urllib.parse import urlparse + +import aiohttp +import requests + +from vllm.version import __version__ as VLLM_VERSION + + +class HTTPConnection: + """Helper class to send HTTP requests.""" + + def __init__(self, *, reuse_client: bool = True) -> None: + super().__init__() + + self.reuse_client = reuse_client + + self._sync_client: Optional[requests.Session] = None + self._async_client: Optional[aiohttp.ClientSession] = None + + def get_sync_client(self) -> requests.Session: + if self._sync_client is None or not self.reuse_client: + self._sync_client = requests.Session() + + return self._sync_client + + # NOTE: We intentionally use an async function even though it is not + # required, so that the client is only accessible inside async event loop + async def get_async_client(self) -> aiohttp.ClientSession: + if self._async_client is None or not self.reuse_client: + self._async_client = aiohttp.ClientSession(trust_env=True) + + return self._async_client + + def _validate_http_url(self, url: str): + parsed_url = urlparse(url) + + if parsed_url.scheme not in ("http", "https"): + raise ValueError("Invalid HTTP URL: A valid HTTP URL " + "must have scheme 'http' or 'https'.") + + def _headers(self, **extras: str) -> MutableMapping[str, str]: + return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras} + + def get_response( + self, + url: str, + *, + stream: bool = False, + timeout: Optional[float] = None, + extra_headers: Optional[Mapping[str, str]] = None, + ): + self._validate_http_url(url) + + client = self.get_sync_client() + extra_headers = extra_headers or {} + + return client.get(url, + headers=self._headers(**extra_headers), + stream=stream, + timeout=timeout) + + async def get_async_response( + self, + url: str, + *, + timeout: Optional[float] = None, + extra_headers: Optional[Mapping[str, str]] = None, + ): + self._validate_http_url(url) + + client = await self.get_async_client() + extra_headers = extra_headers or {} + + return client.get(url, + headers=self._headers(**extra_headers), + timeout=timeout) + + def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + return r.content + + async def async_get_bytes( + self, + url: str, + *, + timeout: Optional[float] = None, + ) -> bytes: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + return await r.read() + + def get_text(self, url: str, *, timeout: Optional[float] = None) -> str: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + return r.text + + async def async_get_text( + self, + url: str, + *, + timeout: Optional[float] = None, + ) -> str: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + return await r.text() + + def get_json(self, url: str, *, timeout: Optional[float] = None) -> str: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + return r.json() + + async def async_get_json( + self, + url: str, + *, + timeout: Optional[float] = None, + ) -> str: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + return await r.json() + + def download_file( + self, + url: str, + save_path: Path, + *, + timeout: Optional[float] = None, + chunk_size: int = 128, + ) -> Path: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + with save_path.open("wb") as f: + for chunk in r.iter_content(chunk_size): + f.write(chunk) + + return save_path + + async def async_download_file( + self, + url: str, + save_path: Path, + *, + timeout: Optional[float] = None, + chunk_size: int = 128, + ) -> Path: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + with save_path.open("wb") as f: + async for chunk in r.content.iter_chunked(chunk_size): + f.write(chunk) + + return save_path + + +global_http_connection = HTTPConnection() +""" +The global [`HTTPConnection`][vllm.connections.HTTPConnection] instance used +by vLLM. +""" diff --git a/vllm/core/__init__.py b/vllm/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/core/block/__init__.py b/vllm/core/block/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py new file mode 100644 index 0000000..444bb25 --- /dev/null +++ b/vllm/core/block/block_table.py @@ -0,0 +1,399 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from typing import List, Optional + +from vllm.core.block.common import BlockList +from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator +from vllm.utils import Device, cdiv, chunk_list + + +class BlockTable: + """A class to manage blocks for a specific sequence. + + The BlockTable maps a sequence of tokens to a list of blocks, where each + block represents a contiguous memory allocation for a portion of the + sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is + responsible for allocating and freeing memory for the blocks. + + Args: + block_size (int): The maximum number of tokens that can be stored in a + single block. + block_allocator (DeviceAwareBlockAllocator): The block allocator used to + manage memory for the blocks. + _blocks (Optional[List[Block]], optional): An optional list of existing + blocks to initialize the BlockTable with. If not provided, an empty + BlockTable is created. + max_block_sliding_window (Optional[int], optional): The number of + blocks to keep around for each sequence. If None, all blocks + are kept (eg., when sliding window is not used). + It should at least fit the sliding window size of the model. + + Attributes: + _block_size (int): The maximum number of tokens that can be stored in a + single block. + _allocator (DeviceAwareBlockAllocator): The block allocator used to + manage memory for the blocks. + _blocks (Optional[List[Block]]): The list of blocks managed by this + BlockTable. + _num_full_slots (int): The number of tokens currently stored in the + blocks. + """ + + def __init__( + self, + block_size: int, + block_allocator: DeviceAwareBlockAllocator, + _blocks: Optional[List[Block]] = None, + max_block_sliding_window: Optional[int] = None, + ): + self._block_size = block_size + self._allocator = block_allocator + if _blocks is None: + _blocks = [] + self._blocks: BlockList = BlockList(_blocks) + + self._max_block_sliding_window = max_block_sliding_window + self._num_full_slots = self._get_num_token_ids() + + @staticmethod + def get_num_required_blocks(token_ids: List[int], + block_size: int, + num_lookahead_slots: int = 0) -> int: + """Calculates the minimum number of blocks required to store a given + sequence of token IDs along with any look-ahead slots that may be + required (like in multi-step + chunked-prefill). + + This assumes worst-case scenario, where every block requires a new + allocation (e.g. ignoring prefix caching). + + Args: + token_ids (List[int]): The sequence of token IDs to be stored. + block_size (int): The maximum number of tokens that can be stored in + a single block. + num_lookahead_slots (int): look-ahead slots that the sequence may + require. + + Returns: + int: The minimum number of blocks required to store the given + sequence of token IDs along with any required look-ahead slots. + """ + return cdiv(len(token_ids) + num_lookahead_slots, block_size) + + def allocate(self, + token_ids: List[int], + device: Device = Device.GPU, + extra_hash: Optional[int] = None) -> None: + """Allocates memory blocks for storing the given sequence of token IDs. + + This method allocates the required number of blocks to store the given + sequence of token IDs. + + Args: + token_ids (List[int]): The sequence of token IDs to be stored. + device (Device, optional): The device on which the blocks should be + allocated. Defaults to Device.GPU. + extra_hash (Optional[int]): The hash value of additional + factors, such as adapters, that influence the block hash + in the prefixcaching block. + """ + assert not self._is_allocated + assert token_ids + blocks = self._allocate_blocks_for_token_ids(prev_block=None, + token_ids=token_ids, + device=device, + extra_hash=extra_hash) + self.update(blocks) + self._num_full_slots = len(token_ids) + + def update(self, blocks: List[Block]) -> None: + """Resets the table to the newly provided blocks + (with their corresponding block ids) + """ + self._blocks.update(blocks) + + def append_token_ids(self, + token_ids: List[int], + num_lookahead_slots: int = 0, + num_computed_slots: Optional[int] = None, + extra_hash: Optional[int] = None) -> None: + """Appends a sequence of token IDs to the existing blocks in the + BlockTable. + + This method appends the given sequence of token IDs to the existing + blocks in the BlockTable. If there is not enough space in the existing + blocks, new blocks are allocated using the `ensure_num_empty_slots` + method to accommodate the additional tokens. + + The token IDs are divided into chunks of size `block_size` (except for + the first chunk, which may be smaller), and each chunk is appended to a + separate block. + + Args: + token_ids (List[int]): The sequence of token IDs to be appended. + num_computed_slots (Optional[int]): The number of KV cache slots + that are already filled (computed). + When sliding window is enabled, this is used to compute how many + blocks to drop at the front of the sequence. + Without sliding window, None can be passed. + Without chunked prefill, it should be the same as + _num_full_slots. + extra_hash (Optional[int]): The hash value of additional + factors such as adapters that influence the block, apart + from the token_ids. + """ + assert self._is_allocated, "no blocks have been allocated" + assert len(self._blocks) > 0 + + # Drop blocks that are no longer needed due to sliding window + if self._max_block_sliding_window is not None: + null_block = self._allocator.allocate_or_get_null_block() + assert num_computed_slots is not None + end_block_idx = (num_computed_slots // + self._block_size) - self._max_block_sliding_window + for idx in range(0, end_block_idx): + b = self._blocks[idx] + if b is not null_block: + self._allocator.free(b) + self._blocks[idx] = null_block + + # Ensure there are enough empty slots for the new tokens plus + # lookahead slots + self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + + num_lookahead_slots, + extra_hash=extra_hash) + + # Update the blocks with the new tokens + first_block_idx = self._num_full_slots // self._block_size + token_blocks = self._chunk_token_blocks_for_append(token_ids) + + for i, token_block in enumerate(token_blocks): + self._blocks.append_token_ids(first_block_idx + i, token_block) + + self._num_full_slots += len(token_ids) + + def ensure_num_empty_slots(self, + num_empty_slots: int, + extra_hash: Optional[int] = None) -> None: + """Ensures that the BlockTable has at least the specified number of + empty slots available. + + This method checks if the BlockTable has enough empty slots (i.e., + available space) to accommodate the requested number of tokens. If not, + it allocates additional blocks on the GPU to ensure that the required + number of empty slots is available. + + Args: + num_empty_slots (int): The minimum number of empty slots required. + extra_hash (Optional[int]): The hash value of additional + factors such as adapters that influence the block, apart + from the token_ids. + """ + # Currently the block table only supports + # appending tokens to GPU blocks. + device = Device.GPU + assert self._is_allocated + + if self._num_empty_slots >= num_empty_slots: + return + + slots_to_allocate = num_empty_slots - self._num_empty_slots + blocks_to_allocate = cdiv(slots_to_allocate, self._block_size) + + for _ in range(blocks_to_allocate): + assert len(self._blocks) > 0 + self._blocks.append( + self._allocator.allocate_mutable_block( + prev_block=self._blocks[-1], + device=device, + extra_hash=extra_hash)) + + def fork(self) -> "BlockTable": + """Creates a new BlockTable instance with a copy of the blocks from the + current instance. + + This method creates a new BlockTable instance with the same block size, + block allocator, and a copy of the blocks from the current instance. The + new BlockTable has its own independent set of blocks, but shares the + same underlying memory allocation with the original BlockTable. + + Returns: + BlockTable: A new BlockTable instance with a copy of the blocks from + the current instance. + """ + assert self._is_allocated + assert len(self._blocks) > 0 + forked_blocks = self._allocator.fork(self._blocks[-1]) + return BlockTable( + block_size=self._block_size, + block_allocator=self._allocator, + _blocks=forked_blocks, + max_block_sliding_window=self._max_block_sliding_window, + ) + + def free(self) -> None: + """Frees the memory occupied by the blocks in the BlockTable. + + This method iterates over all the blocks in the `_blocks` list and calls + the `free` method of the `_allocator` object to release the memory + occupied by each block. After freeing all the blocks, the `_blocks` list + is set to `None`. + """ + for block in self.blocks: + self._allocator.free(block) + self._blocks.reset() + + @property + def physical_block_ids(self) -> List[int]: + """Returns a list of physical block indices for the blocks in the + BlockTable. + + This property returns a list of integers, where each integer represents + the physical block index of a corresponding block in the `_blocks` list. + The physical block index is a unique identifier for the memory location + occupied by the block. + + Returns: + List[int]: A list of physical block indices for the blocks in the + BlockTable. + """ + return self._blocks.ids() + + def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]: + """Get the number of "unseen" tokens in the sequence. + + Unseen tokens are tokens in the sequence corresponding to this block + table, but are not yet appended to this block table. + + Args: + sequence_token_ids (List[int]): The list of token ids in the + sequence. + + Returns: + List[int]: The postfix of sequence_token_ids that has not yet been + appended to the block table. + """ + + # Since the block table is append-only, the unseen token ids are the + # ones after the appended ones. + return sequence_token_ids[self.num_full_slots:] + + def _allocate_blocks_for_token_ids( + self, + prev_block: Optional[Block], + token_ids: List[int], + device: Device, + extra_hash: Optional[int] = None) -> List[Block]: + blocks: List[Block] = [] + + block_token_ids = [] + tail_token_ids = [] + for cur_token_ids in chunk_list(token_ids, self._block_size): + if len(cur_token_ids) == self._block_size: + block_token_ids.append(cur_token_ids) + else: + tail_token_ids.append(cur_token_ids) + + if block_token_ids: + blocks.extend( + self._allocator.allocate_immutable_blocks( + prev_block, + block_token_ids=block_token_ids, + device=device, + extra_hash=extra_hash)) + prev_block = blocks[-1] + + if tail_token_ids: + assert len(tail_token_ids) == 1 + cur_token_ids = tail_token_ids[0] + + block = self._allocator.allocate_mutable_block( + prev_block=prev_block, device=device, extra_hash=extra_hash) + block.append_token_ids(cur_token_ids) + + blocks.append(block) + + return blocks + + def _get_all_token_ids(self) -> List[int]: + # NOTE: This function is O(seq_len); use sparingly. + token_ids: List[int] = [] + + if not self._is_allocated: + return token_ids + + for block in self.blocks: + token_ids.extend(block.token_ids) + + return token_ids + + def _get_num_token_ids(self) -> int: + res = 0 + for block in self.blocks: + res += len(block.token_ids) + + return res + + @property + def _is_allocated(self) -> bool: + return len(self._blocks) > 0 + + @property + def blocks(self) -> List[Block]: + return self._blocks.list() + + @property + def _num_empty_slots(self) -> int: + assert self._is_allocated + return len(self._blocks) * self._block_size - self._num_full_slots + + @property + def num_full_slots(self) -> int: + """Returns the total number of tokens currently stored in the + BlockTable. + + Returns: + int: The total number of tokens currently stored in the BlockTable. + """ + return self._num_full_slots + + def get_num_blocks_touched_by_append_slots( + self, token_ids: List[int], num_lookahead_slots: int) -> int: + """Determine how many blocks will be "touched" by appending the token + ids. + + This is required for the scheduler to determine whether a sequence can + continue generation, or if it must be preempted. + """ + # Math below is equivalent to: + # all_token_ids = token_ids + [-1] * num_lookahead_slots + # token_blocks = self._chunk_token_blocks_for_append(all_token_ids) + # return len(token_blocks) + + num_token_ids = len(token_ids) + num_lookahead_slots + first_chunk_size = self._block_size - (self._num_full_slots % + self._block_size) + num_token_blocks = (1 + math.ceil( + (num_token_ids - first_chunk_size) / self._block_size)) + return num_token_blocks + + def _chunk_token_blocks_for_append( + self, token_ids: List[int]) -> List[List[int]]: + """Split the token ids into block-sized chunks so they can be easily + appended to blocks. The first such "token block" may have less token ids + than the block size, since the last allocated block may be partially + full. + + If no token ids are provided, then no chunks are returned. + """ + + if not token_ids: + return [] + + first_chunk_size = self._block_size - (self._num_full_slots % + self._block_size) + token_blocks = [token_ids[:first_chunk_size]] + token_blocks.extend( + chunk_list(token_ids[first_chunk_size:], self._block_size)) + return token_blocks diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py new file mode 100644 index 0000000..a337007 --- /dev/null +++ b/vllm/core/block/common.py @@ -0,0 +1,371 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections import deque +from dataclasses import dataclass +from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple + +from vllm.core.block.interfaces import Block, BlockAllocator + +BlockId = int +RefCount = int + + +class RefCounterProtocol(Protocol): + + def incr(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + def decr(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + def get(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + +class RefCounter(RefCounterProtocol): + """A class for managing reference counts for a set of block indices. + + The RefCounter class maintains a dictionary that maps block indices to their + corresponding reference counts. It provides methods to increment, decrement, + and retrieve the reference count for a given block index. + + Args: + all_block_indices (Iterable[BlockId]): An iterable of block indices + to initialize the reference counter with. + """ + + def __init__(self, all_block_indices: Iterable[BlockId]): + deduped = set(all_block_indices) + self._refcounts: Dict[BlockId, RefCount] = { + index: 0 + for index in deduped + } + + def incr(self, block_id: BlockId) -> RefCount: + assert block_id in self._refcounts + pre_incr_refcount = self._refcounts[block_id] + + assert pre_incr_refcount >= 0 + + post_incr_refcount = pre_incr_refcount + 1 + self._refcounts[block_id] = post_incr_refcount + return post_incr_refcount + + def decr(self, block_id: BlockId) -> RefCount: + assert block_id in self._refcounts + refcount = self._refcounts[block_id] + + assert refcount > 0 + refcount -= 1 + + self._refcounts[block_id] = refcount + + return refcount + + def get(self, block_id: BlockId) -> RefCount: + assert block_id in self._refcounts + return self._refcounts[block_id] + + def as_readonly(self) -> "ReadOnlyRefCounter": + return ReadOnlyRefCounter(self) + + +class ReadOnlyRefCounter(RefCounterProtocol): + """A read-only view of the RefCounter class. + + The ReadOnlyRefCounter class provides a read-only interface to access the + reference counts maintained by a RefCounter instance. It does not allow + modifications to the reference counts. + + Args: + refcounter (RefCounter): The RefCounter instance to create a read-only + view for. + """ + + def __init__(self, refcounter: RefCounter): + self._refcounter = refcounter + + def incr(self, block_id: BlockId) -> RefCount: + raise ValueError("Incr not allowed") + + def decr(self, block_id: BlockId) -> RefCount: + raise ValueError("Decr not allowed") + + def get(self, block_id: BlockId) -> RefCount: + return self._refcounter.get(block_id) + + +class CopyOnWriteTracker: + """A class for tracking and managing copy-on-write operations for blocks. + + The CopyOnWriteTracker class maintains a mapping of source block indices to + their corresponding copy-on-write destination block indices. It works in + conjunction with a RefCounter. + + Args: + refcounter (RefCounter): The reference counter used to track block + reference counts. + """ + + def __init__(self, refcounter: RefCounterProtocol): + self._copy_on_writes: List[Tuple[BlockId, BlockId]] = [] + self._refcounter = refcounter + + def is_appendable(self, block: Block) -> bool: + """Checks if the block is shared or not. If shared, then it cannot + be appended and needs to be duplicated via copy-on-write + """ + block_id = block.block_id + if block_id is None: + return True + + refcount = self._refcounter.get(block_id) + return refcount <= 1 + + def record_cow(self, src_block_id: Optional[BlockId], + trg_block_id: Optional[BlockId]) -> None: + """Records a copy-on-write operation from source to target block id + Args: + src_block_id (BlockId): The source block id from which to copy + the data + trg_block_id (BlockId): The target block id to which the data + is copied + """ + assert src_block_id is not None + assert trg_block_id is not None + self._copy_on_writes.append((src_block_id, trg_block_id)) + + def clear_cows(self) -> List[Tuple[BlockId, BlockId]]: + """Clears the copy-on-write tracking information and returns the current + state. + + This method returns a list mapping source block indices to + destination block indices for the current copy-on-write operations. + It then clears the internal tracking information. + + Returns: + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices for the + current copy-on-write operations. + """ + cows = self._copy_on_writes + self._copy_on_writes = [] + return cows + + +class BlockPool: + """Used to pre-allocate block objects, in order to avoid excessive python + object allocations/deallocations. + The pool starts from "pool_size" objects and will increase to more objects + if necessary + + Note that multiple block objects may point to the same physical block id, + which is why this pool is needed, so that it will be easier to support + prefix caching and more complicated sharing of physical blocks. + """ + + def __init__(self, block_size: int, create_block: Block.Factory, + allocator: BlockAllocator, pool_size: int): + self._block_size = block_size + self._create_block = create_block + self._allocator = allocator + self._pool_size = pool_size + assert self._pool_size >= 0 + + self._free_ids: Deque[int] = deque(range(self._pool_size)) + self._pool = [] + for i in range(self._pool_size): + self._pool.append( + self._create_block(prev_block=None, + token_ids=[], + block_size=self._block_size, + allocator=self._allocator, + block_id=None, + extra_hash=None)) + + def increase_pool(self): + """Doubles the internal pool size + """ + cur_pool_size = self._pool_size + new_pool_size = cur_pool_size * 2 + self._pool_size = new_pool_size + + self._free_ids += deque(range(cur_pool_size, new_pool_size)) + + for i in range(cur_pool_size, new_pool_size): + self._pool.append( + self._create_block(prev_block=None, + token_ids=[], + block_size=self._block_size, + allocator=self._allocator, + block_id=None, + extra_hash=None)) + + def init_block(self, + prev_block: Optional[Block], + token_ids: List[int], + block_size: int, + physical_block_id: Optional[int], + extra_hash: Optional[int] = None) -> Block: + if len(self._free_ids) == 0: + self.increase_pool() + assert len(self._free_ids) > 0 + + pool_id = self._free_ids.popleft() + + block = self._pool[pool_id] + block.__init__( # type: ignore[misc] + prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + allocator=block._allocator, # type: ignore[attr-defined] + block_id=physical_block_id, + extra_hash=extra_hash) + block.pool_id = pool_id # type: ignore[attr-defined] + return block + + def free_block(self, block: Block) -> None: + self._free_ids.appendleft(block.pool_id) # type: ignore[attr-defined] + + +class BlockList: + """This class is an optimization to allow fast-access to physical + block ids. It maintains a block id list that is updated with the + block list and this avoids the need to reconstruct the block id + list on every iteration of the block manager + """ + + def __init__(self, blocks: List[Block]): + self._blocks: List[Block] = [] + self._block_ids: List[int] = [] + + self.update(blocks) + + def _add_block_id(self, block_id: Optional[BlockId]) -> None: + assert block_id is not None + self._block_ids.append(block_id) + + def _update_block_id(self, block_index: int, + new_block_id: Optional[BlockId]) -> None: + assert new_block_id is not None + self._block_ids[block_index] = new_block_id + + def update(self, blocks: List[Block]): + self._blocks = blocks + + # Cache block ids for fast query + self._block_ids = [] + for block in self._blocks: + self._add_block_id(block.block_id) + + def append_token_ids(self, block_index: int, token_ids: List[int]) -> None: + block = self._blocks[block_index] + prev_block_id = block.block_id + + block.append_token_ids(token_ids) + + # CoW or promotion may update the internal block_id + if prev_block_id != block.block_id: + self._update_block_id(block_index, block.block_id) + + def append(self, new_block: Block): + self._blocks.append(new_block) + self._add_block_id(new_block.block_id) + + def __len__(self) -> int: + return len(self._blocks) + + def __getitem__(self, block_index: int) -> Block: + return self._blocks[block_index] + + def __setitem__(self, block_index: int, new_block: Block) -> None: + self._blocks[block_index] = new_block + self._update_block_id(block_index, new_block.block_id) + + def reset(self): + self._blocks = [] + self._block_ids = [] + + def list(self) -> List[Block]: + return self._blocks + + def ids(self) -> List[int]: + return self._block_ids + + +@dataclass +class CacheMetricData: + """A utility dataclass to maintain cache metric. + To avoid overflow, we maintain the hit rate in block granularity, so that + we can maintain a single hit rate for n_completed_block x block_size, + and calculate the real time hit rate by the following: + BS = The number of queries per block. + nB = The number of completed blocks. + HR = hit rate of (nB x BS) queries. + Q = current number of queries (< BS). + H = current number of hits (< BS). + hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS) + """ + num_completed_blocks: int = 0 + completed_block_cache_hit_rate: float = 0.0 + num_incompleted_block_queries: int = 0 + num_incompleted_block_hit: int = 0 + block_size: int = 1000 + + def query(self, hit: bool): + self.num_incompleted_block_queries += 1 + self.num_incompleted_block_hit += 1 if hit else 0 + + # When a block is completed, update the cache hit rate + # and reset the incomplete numbers. + if self.num_incompleted_block_queries == self.block_size: + hit_rate = (self.num_incompleted_block_hit / + self.num_incompleted_block_queries) + self.completed_block_cache_hit_rate = ( + self.completed_block_cache_hit_rate * self.num_completed_blocks + + hit_rate) / (self.num_completed_blocks + 1) + self.num_incompleted_block_queries = 0 + self.num_incompleted_block_hit = 0 + self.num_completed_blocks += 1 + + def get_hit_rate(self): + incomplete_ratio = self.num_incompleted_block_queries / self.block_size + total_blocks = self.num_completed_blocks + incomplete_ratio + if total_blocks == 0: + return 0.0 + + completed_block_hit, incompleted_block_hit = 0.0, 0.0 + if self.num_completed_blocks > 0: + completed_block_hit = (self.completed_block_cache_hit_rate * + self.num_completed_blocks) + if self.num_incompleted_block_queries > 0: + incompleted_hit_rate = (self.num_incompleted_block_hit / + self.num_incompleted_block_queries) + incompleted_block_hit = (incompleted_hit_rate * incomplete_ratio) + return (completed_block_hit + incompleted_block_hit) / total_blocks + + +def get_all_blocks_recursively(last_block: Block) -> List[Block]: + """Retrieves all the blocks in a sequence starting from the last block. + + This function recursively traverses the sequence of blocks in reverse order, + starting from the given last block, and returns a list of all the blocks in + the sequence. + + Args: + last_block (Block): The last block in the sequence. + + Returns: + List[Block]: A list of all the blocks in the sequence, in the order they + appear. + """ + + def recurse(block: Block, lst: List[Block]) -> None: + if block.prev_block is not None: + recurse(block.prev_block, lst) + lst.append(block) + + all_blocks: List[Block] = [] + recurse(last_block, all_blocks) + return all_blocks diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py new file mode 100644 index 0000000..ea490c3 --- /dev/null +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -0,0 +1,441 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Dict, FrozenSet, List, Optional, Tuple + +from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, + DeviceAwareBlockAllocator) +from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator +from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator +from vllm.platforms import current_platform +from vllm.utils import Device + + +class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): + """A block allocator that can allocate blocks on both CPU and GPU memory. + + This class implements the `DeviceAwareBlockAllocator` interface and provides + functionality for allocating and managing blocks of memory on both CPU and + GPU devices. + + The `CpuGpuBlockAllocator` maintains separate memory pools for CPU and GPU + blocks, and allows for allocation, deallocation, forking, and swapping of + blocks across these memory pools. + """ + + @staticmethod + def create( + allocator_type: str, + num_gpu_blocks: int, + num_cpu_blocks: int, + block_size: int, + ) -> DeviceAwareBlockAllocator: + """Creates a CpuGpuBlockAllocator instance with the specified + configuration. + + This static method creates and returns a CpuGpuBlockAllocator instance + based on the provided parameters. It initializes the CPU and GPU block + allocators with the specified number of blocks, block size, and + allocator type. + + Args: + allocator_type (str): The type of block allocator to use for CPU + and GPU blocks. Currently supported values are "naive" and + "prefix_caching". + num_gpu_blocks (int): The number of blocks to allocate for GPU + memory. + num_cpu_blocks (int): The number of blocks to allocate for CPU + memory. + block_size (int): The size of each block in number of tokens. + + Returns: + DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with the + specified configuration. + + Notes: + - The block IDs are assigned contiguously, with GPU block IDs coming + before CPU block IDs. + """ + # For HPU, block id 0 is used only for padding + reserved_blocks = 1 if current_platform.is_hpu() else 0 + block_ids = list( + range(reserved_blocks, num_gpu_blocks + num_cpu_blocks)) + num_gpu_blocks -= reserved_blocks + gpu_block_ids = block_ids[:num_gpu_blocks] + cpu_block_ids = block_ids[num_gpu_blocks:] + + if allocator_type == "naive": + gpu_allocator: BlockAllocator = NaiveBlockAllocator( + create_block=NaiveBlock, # type: ignore + num_blocks=num_gpu_blocks, + block_size=block_size, + block_ids=gpu_block_ids, + ) + + cpu_allocator: BlockAllocator = NaiveBlockAllocator( + create_block=NaiveBlock, # type: ignore + num_blocks=num_cpu_blocks, + block_size=block_size, + block_ids=cpu_block_ids, + ) + elif allocator_type == "prefix_caching": + gpu_allocator = PrefixCachingBlockAllocator( + num_blocks=num_gpu_blocks, + block_size=block_size, + block_ids=gpu_block_ids, + ) + + cpu_allocator = PrefixCachingBlockAllocator( + num_blocks=num_cpu_blocks, + block_size=block_size, + block_ids=cpu_block_ids, + ) + else: + raise ValueError(f"Unknown allocator type {allocator_type=}") + + return CpuGpuBlockAllocator( + cpu_block_allocator=cpu_allocator, + gpu_block_allocator=gpu_allocator, + ) + + def __init__(self, cpu_block_allocator: BlockAllocator, + gpu_block_allocator: BlockAllocator): + assert not ( + cpu_block_allocator.all_block_ids + & gpu_block_allocator.all_block_ids + ), "cpu and gpu block allocators can't have intersection of block ids" + + self._allocators = { + Device.CPU: cpu_block_allocator, + Device.GPU: gpu_block_allocator, + } + + self._swap_mapping: Dict[int, int] = {} + self._null_block: Optional[Block] = None + + self._block_ids_to_allocator: Dict[int, BlockAllocator] = {} + for _, allocator in self._allocators.items(): + for block_id in allocator.all_block_ids: + self._block_ids_to_allocator[block_id] = allocator + + def allocate_or_get_null_block(self) -> Block: + if self._null_block is None: + self._null_block = NullBlock( + self.allocate_mutable_block(None, Device.GPU)) + return self._null_block + + def allocate_mutable_block(self, + prev_block: Optional[Block], + device: Device, + extra_hash: Optional[int] = None) -> Block: + """Allocates a new mutable block on the specified device. + + Args: + prev_block (Optional[Block]): The previous block to in the sequence. + Used for prefix hashing. + device (Device): The device on which to allocate the new block. + extra_hash (Optional[int]): The hash value of additional + factors, such as adapters, that influence the block hash + in the prefix caching block. + + Returns: + Block: The newly allocated mutable block. + """ + return self._allocators[device].allocate_mutable_block( + prev_block, extra_hash=extra_hash) + + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Device, + extra_hash: Optional[int] = None) -> List[Block]: + """Allocates a new group of immutable blocks with the provided block + token IDs on the specified device. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. + Used for prefix hashing. + block_token_ids (List[int]): The list of block token IDs to be + stored in the new blocks. + device (Device): The device on which to allocate the new block. + extra_hash (Optional[int]): The hash value of additional + factors, such as adapters, that influence the block hash + in the prefix caching block. + + Returns: + List[Block]: The newly allocated list of immutable blocks + containing the provided block token IDs. + """ + return self._allocators[device].allocate_immutable_blocks( + prev_block, block_token_ids, extra_hash=extra_hash) + + def allocate_immutable_block(self, + prev_block: Optional[Block], + token_ids: List[int], + device: Device, + extra_hash: Optional[int] = None) -> Block: + """Allocates a new immutable block with the provided token IDs on the + specified device. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. + Used for prefix hashing. + token_ids (List[int]): The list of token IDs to be stored in the new + block. + device (Device): The device on which to allocate the new block. + extra_hash (Optional[int]): The hash value of additional + factors, such as adapters, that influence the block hash + in the prefix caching block. + + Returns: + Block: The newly allocated immutable block containing the provided + token IDs. + """ + return self._allocators[device].allocate_immutable_block( + prev_block, token_ids, extra_hash=extra_hash) + + def free(self, block: Block) -> None: + """Frees the memory occupied by the given block. + + Args: + block (Block): The block to be freed. + """ + # Null block should never be freed + if isinstance(block, NullBlock): + return + block_id = block.block_id + assert block_id is not None + allocator = self._block_ids_to_allocator[block_id] + allocator.free(block) + + def fork(self, last_block: Block) -> List[Block]: + """Creates a new sequence of blocks that shares the same underlying + memory as the original sequence. + + Args: + last_block (Block): The last block in the original sequence. + + Returns: + List[Block]: A new list of blocks that shares the same memory as the + original sequence. + """ + # do not attempt to fork the null block + assert not isinstance(last_block, NullBlock) + block_id = last_block.block_id + assert block_id is not None + allocator = self._block_ids_to_allocator[block_id] + return allocator.fork(last_block) + + def get_num_free_blocks(self, device: Device) -> int: + """Returns the number of free blocks available on the specified device. + + Args: + device (Device): The device for which to query the number of free + blocks. AssertionError is raised if None is passed. + + Returns: + int: The number of free blocks available on the specified device. + """ + return self._allocators[device].get_num_free_blocks() + + def get_num_total_blocks(self, device: Device) -> int: + return self._allocators[device].get_num_total_blocks() + + def get_physical_block_id(self, device: Device, absolute_id: int) -> int: + """Returns the zero-offset block id on certain device given the + absolute block id. + + Args: + device (Device): The device for which to query relative block id. + absolute_id (int): The absolute block id for the block in + whole allocator. + + Returns: + int: The zero-offset block id on certain device. + """ + return self._allocators[device].get_physical_block_id(absolute_id) + + def swap(self, blocks: List[Block], src_device: Device, + dst_device: Device) -> Dict[int, int]: + """Execute the swap for the given blocks from source_device + on to dest_device, save the current swap mapping and append + them to the accumulated `self._swap_mapping` for each + scheduling move. + + Args: + blocks: List of blocks to be swapped. + src_device (Device): Device to swap the 'blocks' from. + dst_device (Device): Device to swap the 'blocks' to. + + Returns: + Dict[int, int]: Swap mapping from source_device + on to dest_device. + """ + src_block_ids = [block.block_id for block in blocks] + self._allocators[src_device].swap_out(blocks) + self._allocators[dst_device].swap_in(blocks) + dst_block_ids = [block.block_id for block in blocks] + + current_swap_mapping: Dict[int, int] = {} + for src_block_id, dst_block_id in zip(src_block_ids, dst_block_ids): + if src_block_id is not None and dst_block_id is not None: + self._swap_mapping[src_block_id] = dst_block_id + current_swap_mapping[src_block_id] = dst_block_id + return current_swap_mapping + + def get_num_full_blocks_touched(self, blocks: List[Block], + device: Device) -> int: + """Returns the number of full blocks that will be touched by + swapping in/out the given blocks on to the 'device'. + + Args: + blocks: List of blocks to be swapped. + device (Device): Device to swap the 'blocks' on. + + Returns: + int: the number of full blocks that will be touched by + swapping in/out the given blocks on to the 'device'. + Non full blocks are ignored when deciding the number + of blocks to touch. + """ + return self._allocators[device].get_num_full_blocks_touched(blocks) + + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: + """Clears the copy-on-write (CoW) state and returns the mapping of + source to destination block IDs. + + Returns: + List[Tuple[int, int]]: A list mapping source block IDs to + destination block IDs. + """ + # CoW only supported on GPU + device = Device.GPU + return self._allocators[device].clear_copy_on_writes() + + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, only use for prefix caching.""" + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].mark_blocks_as_accessed(block_ids, now) + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + """Mark blocks as accessed, only use for prefix caching.""" + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].mark_blocks_as_computed(block_ids) + + def get_common_computed_block_ids( + self, computed_seq_block_ids: List[List[int]]) -> List[int]: + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].get_common_computed_block_ids( + computed_seq_block_ids) + + @property + def all_block_ids(self) -> FrozenSet[int]: + return frozenset(self._block_ids_to_allocator.keys()) + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + assert device in self._allocators + return self._allocators[device].get_prefix_cache_hit_rate() + + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + """Reset prefix cache for specified or all devices.""" + if device: + return self._allocators[device].reset_prefix_cache() + success = True + for allocator in self._allocators.values(): + success = success and allocator.reset_prefix_cache() + return success + + def get_and_reset_swaps(self) -> List[Tuple[int, int]]: + """Returns and clears the mapping of source to destination block IDs. + Will be called after every swapping operations for now, and after every + schedule when BlockManagerV2 become default. Currently not useful. + + Returns: + List[Tuple[int, int]]: A mapping of source to destination block IDs. + """ + mapping = self._swap_mapping.copy() + self._swap_mapping.clear() + return list(mapping.items()) + + def find_cached_blocks_prefix( + self, + block_hashes: List[int], + device: Device = Device.GPU, + ) -> List[int]: + return self._allocators[device].find_cached_blocks_prefix(block_hashes) + + +class NullBlock(Block): + """ + Null blocks are used as a placeholders for KV cache blocks that have + been dropped due to sliding window. + This implementation just wraps an ordinary block and prevents it from + being modified. It also allows for testing if a block is NullBlock + via isinstance(). + """ + + def __init__(self, proxy: Block): + super().__init__() + self._proxy = proxy + + def append_token_ids(self, token_ids: List[BlockId]): + raise ValueError("null block should not be modified") + + @property + def block_id(self): + return self._proxy.block_id + + @block_id.setter + def block_id(self, value: Optional[BlockId]): + raise ValueError("null block should not be modified") + + @property + def token_ids(self) -> List[BlockId]: + return self._proxy.token_ids + + @property + def num_tokens_total(self) -> int: + raise NotImplementedError( + "num_tokens_total is not used for null block") + + @property + def num_empty_slots(self) -> BlockId: + return self._proxy.num_empty_slots + + @property + def is_full(self): + return self._proxy.is_full + + @property + def prev_block(self): + return self._proxy.prev_block + + @property + def extra_hash(self): + return None + + @property + def computed(self): + return self._proxy.computed + + @computed.setter + def computed(self, value): + self._proxy.computed = value + + @property + def last_accessed(self) -> float: + return self._proxy.last_accessed + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + self._proxy.last_accessed = last_accessed_ts + + @property + def content_hash(self): + return self._proxy.content_hash diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py new file mode 100644 index 0000000..1a05881 --- /dev/null +++ b/vllm/core/block/interfaces.py @@ -0,0 +1,319 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from typing import Dict, FrozenSet, List, Optional, Protocol, Tuple + +from vllm.utils import Device + +BlockId = int + + +class Block(ABC): + + @abstractmethod + def append_token_ids(self, token_ids: List[int]) -> None: + pass + + @property + @abstractmethod + def block_id(self) -> Optional[int]: + pass + + @block_id.setter + @abstractmethod + def block_id(self, value: Optional[int]) -> None: + """NOTE: Do not use this API outside Block.""" + self._block_id = value + + @property + @abstractmethod + def token_ids(self) -> List[int]: + pass + + @property + @abstractmethod + def num_tokens_total(self) -> int: + """The number of tokens till the current block (inclusive) + """ + pass + + @property + @abstractmethod + def num_empty_slots(self) -> int: + pass + + @property + @abstractmethod + def is_full(self) -> bool: + pass + + @property + @abstractmethod + def prev_block(self) -> Optional["Block"]: + pass + + @property + @abstractmethod + def extra_hash(self) -> Optional[int]: + return None + + @property + @abstractmethod + def computed(self) -> bool: + raise NotImplementedError + + @computed.setter + @abstractmethod + def computed(self, value) -> bool: + """Should be only used by PrefixCacingAllocator""" + raise NotImplementedError + + @property + @abstractmethod + def last_accessed(self) -> float: + raise NotImplementedError + + @last_accessed.setter + @abstractmethod + def last_accessed(self, last_accessed_ts: float): + raise NotImplementedError + + class Factory(Protocol): + + @abstractmethod + def __call__( + self, + prev_block: Optional["Block"], + token_ids: List[int], + block_size: int, + allocator: "BlockAllocator", + block_id: Optional[int] = None, + computed: bool = False, + extra_hash: Optional[int] = None, + ) -> "Block": + pass + + @property + @abstractmethod + def content_hash(self) -> Optional[int]: + """Return the content-based hash of the current block, or None if it is + not yet defined or not supported. + + For the content-based hash to be defined, the current block must be + full. + """ + return None + + +class BlockAllocator(ABC): + + @abstractmethod + def allocate_mutable_block(self, prev_block: Optional[Block], + extra_hash: Optional[int]) -> Block: + pass + + @abstractmethod + def allocate_immutable_block(self, prev_block: Optional[Block], + token_ids: List[int], + extra_hash: Optional[int]) -> Block: + pass + + @abstractmethod + def allocate_immutable_blocks(self, prev_block: Optional[Block], + block_token_ids: List[List[int]], + extra_hash: Optional[int]) -> List[Block]: + pass + + @abstractmethod + def free(self, block: Block) -> None: + pass + + @abstractmethod + def fork(self, last_block: Block) -> List[Block]: + pass + + @abstractmethod + def get_num_total_blocks(self) -> int: + pass + + @abstractmethod + def get_num_free_blocks(self) -> int: + pass + + @abstractmethod + def get_physical_block_id(self, absolute_id: int) -> int: + pass + + @abstractmethod + def swap_out(self, blocks: List[Block]) -> None: + pass + + @abstractmethod + def swap_in(self, blocks: List[Block]) -> None: + pass + + @property + @abstractmethod + def all_block_ids(self) -> FrozenSet[int]: + pass + + @abstractmethod + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + pass + + @abstractmethod + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + pass + + @abstractmethod + def get_common_computed_block_ids( + self, computed_seq_block_ids: List[List[int]]) -> List[int]: + pass + + @abstractmethod + def cow_block_if_not_appendable(self, block: Block) -> BlockId: + """NOTE: This should not be used besides Block""" + pass + + @abstractmethod + def promote_to_immutable_block(self, block: Block) -> BlockId: + """NOTE: This should not be used besides Block""" + pass + + @abstractmethod + def get_num_full_blocks_touched(self, blocks: List[Block]) -> int: + pass + + @abstractmethod + def get_prefix_cache_hit_rate(self) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + pass + + @abstractmethod + def reset_prefix_cache(self) -> bool: + """Reset prefix cache.""" + pass + + class NoFreeBlocksError(ValueError): + pass + + @abstractmethod + def find_cached_blocks_prefix( + self, + block_hashes: List[int], + ) -> List[int]: + pass + + +class DeviceAwareBlockAllocator(ABC): + + @abstractmethod + def allocate_mutable_block(self, + prev_block: Optional[Block], + device: Device, + extra_hash: Optional[int] = None) -> Block: + pass + + @abstractmethod + def allocate_immutable_block(self, + prev_block: Optional[Block], + token_ids: List[int], + device: Device, + extra_hash: Optional[int] = None) -> Block: + pass + + @abstractmethod + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Device, + extra_hash: Optional[int] = None, + ) -> List[Block]: + pass + + @abstractmethod + def get_num_free_blocks(self, device: Device) -> int: + pass + + @abstractmethod + def get_num_total_blocks(self, device: Device) -> int: + pass + + @abstractmethod + def free(self, block: Block) -> None: + pass + + @abstractmethod + def fork(self, last_block: Block) -> List[Block]: + pass + + @property + @abstractmethod + def all_block_ids(self) -> FrozenSet[int]: + pass + + @abstractmethod + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + pass + + @abstractmethod + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + pass + + @abstractmethod + def get_common_computed_block_ids( + self, computed_seq_block_ids: List[List[int]]) -> List[int]: + pass + + @abstractmethod + def get_num_full_blocks_touched(self, blocks: List[Block], + device: Device) -> int: + pass + + @abstractmethod + def swap(self, blocks: List[Block], src_device: Device, + dst_device: Device) -> Dict[int, int]: + pass + + @abstractmethod + def get_physical_block_id(self, device: Device, absolute_id: int) -> int: + pass + + @abstractmethod + def allocate_or_get_null_block(self) -> Block: + """ + Null blocks are used as a placeholders for KV cache blocks that have + been dropped due to sliding window. + There is at most one null block per allocator. + """ + pass + + @abstractmethod + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + pass + + @abstractmethod + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + """Reset prefix cache.""" + pass + + @abstractmethod + def find_cached_blocks_prefix( + self, + block_hashes: List[int], + device: Device = Device.GPU, + ) -> List[int]: + pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py new file mode 100644 index 0000000..dae6ead --- /dev/null +++ b/vllm/core/block/naive_block.py @@ -0,0 +1,466 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections import deque +from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union + +from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter, + get_all_blocks_recursively) +from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device + +Refcount = int + + +class NaiveBlockAllocator(BlockAllocator): + """A simple block allocator that manages blocks of memory without prefix + caching. + + Args: + create_block (Block.Factory): A factory function for creating new + blocks. This is used when a NaiveBlockAllocator is composed within + a prefix caching allocator -- the naive block allocator must + construct prefix caching blocks (but shouldn't know anything else + about them). + num_blocks (int): The total number of blocks to manage. + block_size (int): The size of each block in tokens. + block_ids (Optional[Iterable[int]], optional): An optional iterable of + block IDs. If not provided, block IDs will be assigned sequentially + from 0 to num_blocks - 1. + """ + + def __init__( + self, + create_block: Block.Factory, + num_blocks: int, + block_size: int, + block_ids: Optional[Iterable[int]] = None, + block_pool: Optional[BlockPool] = None, + ): + if block_ids is None: + block_ids = range(num_blocks) + + self._free_block_indices: Deque[BlockId] = deque(block_ids) + self._all_block_indices = frozenset(block_ids) + assert len(self._all_block_indices) == num_blocks + + self._refcounter = RefCounter( + all_block_indices=self._free_block_indices) + self._block_size = block_size + + self._cow_tracker = CopyOnWriteTracker( + refcounter=self._refcounter.as_readonly()) + + if block_pool is None: + extra_factor = 4 + # Pre-allocate "num_blocks * extra_factor" block objects. + # The "* extra_factor" is a buffer to allow more block objects + # than physical blocks + self._block_pool = BlockPool(self._block_size, create_block, self, + num_blocks * extra_factor) + else: + # In this case, the block pool is provided by the caller, + # which means that there is most likely a need to share + # a block pool between allocators + self._block_pool = block_pool + + def allocate_immutable_block(self, + prev_block: Optional[Block], + token_ids: List[int], + extra_hash: Optional[int] = None, + device: Optional[Device] = None) -> Block: + """Allocates a new immutable block with the given token IDs, linked to + the previous block. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. If + None, then the block to be allocated is the first block in the + sequence. + token_ids (List[int]): The token IDs to be stored in the new block. + + Returns: + Block: The newly allocated immutable block. + """ + assert device is None + block = self.allocate_mutable_block(prev_block=prev_block) + block.append_token_ids(token_ids) + return block + + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + extra_hash: Optional[int] = None, + device: Optional[Device] = None) -> List[Block]: + assert device is None + num_blocks = len(block_token_ids) + + block_ids = [] + for i in range(num_blocks): + block_ids.append(self._allocate_block_id()) + + blocks = [] + for i in range(num_blocks): + prev_block = self._block_pool.init_block( + prev_block=prev_block, + token_ids=block_token_ids[i], + block_size=self._block_size, + physical_block_id=block_ids[i]) + blocks.append(prev_block) + + return blocks + + def allocate_mutable_block(self, + prev_block: Optional[Block], + extra_hash: Optional[int] = None, + device: Optional[Device] = None) -> Block: + """Allocates a new mutable block, linked to the previous block. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. If + None, then the block to be allocated is the first block in the + sequence. + + Returns: + Block: The newly allocated mutable block. + """ + assert device is None + block_id = self._allocate_block_id() + block = self._block_pool.init_block(prev_block=prev_block, + token_ids=[], + block_size=self._block_size, + physical_block_id=block_id) + return block + + def _allocate_block_id(self) -> BlockId: + if not self._free_block_indices: + raise BlockAllocator.NoFreeBlocksError() + + block_id = self._free_block_indices.popleft() + self._refcounter.incr(block_id) + return block_id + + def _free_block_id(self, block: Union[Block, BlockId]) -> None: + if isinstance(block, Block): + block_id = block.block_id + block.block_id = None + else: + block_id = block + assert block_id is not None + + refcount = self._refcounter.decr(block_id) + if refcount == 0: + self._free_block_indices.appendleft(block_id) + + def free(self, block: Block, keep_block_object: bool = False) -> None: + # Release the physical block id + self._free_block_id(block) + + # Release the block object + if not keep_block_object: + self._block_pool.free_block(block) + + def free_block_id(self, block_id: BlockId) -> None: + self._free_block_id(block_id) + + def fork(self, last_block: Block) -> List[Block]: + """Creates a new sequence of blocks that shares the same underlying + memory as the original sequence. + + Args: + last_block (Block): The last block in the original sequence. + + Returns: + List[Block]: The new sequence of blocks that shares the same memory + as the original sequence. + """ + source_blocks = get_all_blocks_recursively(last_block) + + forked_blocks: List[Block] = [] + prev_block = None + for block in source_blocks: + + # Increment refcount for each block. + assert block.block_id is not None + refcount = self._refcounter.incr(block.block_id) + assert refcount != 1, "can't fork free'd block" + + forked_block = self._block_pool.init_block( + prev_block=prev_block, + token_ids=block.token_ids, + block_size=self._block_size, + physical_block_id=block.block_id) + + forked_blocks.append(forked_block) + prev_block = forked_blocks[-1] + + return forked_blocks + + def get_num_free_blocks(self) -> int: + return len(self._free_block_indices) + + def get_num_total_blocks(self) -> int: + return len(self._all_block_indices) + + def get_physical_block_id(self, absolute_id: int) -> int: + """Returns the zero-offset block id on certain block allocator + given the absolute block id. + + Args: + absolute_id (int): The absolute block id for the block + in whole allocator. + + Returns: + int: The zero-offset block id on certain device. + """ + return sorted(self._all_block_indices).index(absolute_id) + + @property + def refcounter(self): + return self._refcounter + + @property + def all_block_ids(self) -> FrozenSet[int]: + return self._all_block_indices + + def cow_block_if_not_appendable(self, block: Block) -> BlockId: + """Performs a copy-on-write operation on the given block if it is not + appendable. + + Args: + block (Block): The block to check for copy-on-write. + + Returns: + BlockId: The block index of the new block if a copy-on-write + operation was performed, or the original block index if + no copy-on-write was necessary. + """ + src_block_id = block.block_id + assert src_block_id is not None + + if self._cow_tracker.is_appendable(block): + return src_block_id + + self._free_block_id(block) + trg_block_id = self._allocate_block_id() + + self._cow_tracker.record_cow(src_block_id, trg_block_id) + + return trg_block_id + + def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: + """Returns the copy-on-write source->destination mapping and clears it. + + Returns: + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices. + """ + return self._cow_tracker.clear_cows() + + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, used in prefix caching. + + Since the naive allocator does not implement prefix caching, we do + nothing. + """ + pass + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + """Mark blocks as computed, used in prefix caching. + + Since the naive allocator does not implement prefix caching, we do + nothing. + """ + pass + + def get_common_computed_block_ids( + self, computed_seq_block_ids: List[List[int]]) -> List[int]: + """Determine blocks that can be skipped in prefill. + + Since the naive allocator does not support prefix caching, always return + an empty list. + """ + return [] + + def promote_to_immutable_block(self, block: Block) -> BlockId: + raise NotImplementedError("There is no promotion for naive blocks") + + def get_num_full_blocks_touched(self, blocks: List[Block]) -> int: + """Returns the number of full blocks that will be touched by + swapping in/out. + + Args: + blocks: List of blocks to be swapped. + Returns: + int: the number of full blocks that will be touched by + swapping in/out the given blocks. Non full blocks are ignored + when deciding the number of blocks to touch. + """ + # NOTE: for naive block, we use set to eliminate common blocks among + # seqs, also we compare the empty slots in the mutable blocks with + # lookahead slots to get the number of unique new block that are + # needed. + old_block_set = set() + for block in blocks: + if block.is_full: + old_block_set.add(block) + return len(old_block_set) + + def swap_out(self, blocks: List[Block]) -> None: + for block in blocks: + self._free_block_id(block) + + def swap_in(self, blocks: List[Block]) -> None: + for block in blocks: + # Here we allocate either immutable or mutable block and then + # extract its block_id. Note that the block object is released + # and the block_id is assigned to "block" to allow reusing the + # existing "block" object + if block.is_full: + tmp_block = self.allocate_immutable_block( + prev_block=block.prev_block, token_ids=block.token_ids) + else: + tmp_block = self.allocate_mutable_block( + prev_block=block.prev_block) + tmp_block.append_token_ids(block.token_ids) + + block_id = tmp_block.block_id + tmp_block.block_id = None + self._block_pool.free_block(tmp_block) + + block.block_id = block_id # Assign block_id + + def get_prefix_cache_hit_rate(self) -> float: + return -1 + + def reset_prefix_cache(self) -> bool: + """No prefix cache for naive block allocator.""" + return True + + def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]: + # Not applicable for naive block allocator. + return [] + + +class NaiveBlock(Block): + """An implementation of the Block class that does not support prefix + caching. + + The NaiveBlock class represents a block of token IDs with a fixed size. It + provides methods for appending token IDs to the block and manages copy-on + -write operations when necessary. + + Args: + prev_block (Block): The previous block in the sequence. + token_ids (List[int]): The initial token IDs to be stored in the block. + block_size (int): The maximum number of token IDs that can be stored in + the block. + allocator (BlockAllocator): The block allocator associated with this + block. + block_id (Optional[int], optional): The physical block index + of this block. Defaults to None, which means no allocation has been + made. + _cow_target (Optional[Block], optional): The copy-on-write target block. + If not provided, it defaults to self. + """ + + def __init__(self, + prev_block: Optional[Block], + token_ids: List[int], + block_size: int, + allocator: BlockAllocator, + block_id: Optional[int] = None, + _cow_target: Optional[Block] = None, + extra_hash: Optional[int] = None): + self._token_ids: List[int] = [] + self._block_size = block_size + self._prev_block = prev_block + self._block_id = block_id + self._allocator = allocator + self._cow_target = _cow_target if _cow_target is not None else self + + self._append_token_ids_no_cow(token_ids) + + def append_token_ids(self, token_ids: List[int]) -> None: + """Appends the given token IDs to the block and performs a + copy-on-write if necessary. + + Args: + token_ids (Optional[List[int]]): The token IDs to be appended + to the block. + """ + self._append_token_ids_no_cow(token_ids) + + if self._block_id is not None: + self._block_id = (self._allocator.cow_block_if_not_appendable( + self._cow_target)) + + def _append_token_ids_no_cow(self, token_ids: List[int]) -> None: + """Appends the given token IDs to the block + + Args: + token_ids (List[int]): The token IDs to be appended to the block. + """ + if len(token_ids) == 0: + return + + assert len(token_ids) <= self.num_empty_slots + + self._token_ids.extend(token_ids) + + @property + def computed(self) -> bool: + raise NotImplementedError + + @computed.setter + def computed(self, value) -> None: + raise NotImplementedError + + @property + def last_accessed(self) -> float: + raise NotImplementedError + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + raise NotImplementedError + + @property + def block_id(self) -> Optional[int]: + return self._block_id + + @block_id.setter + def block_id(self, value: Optional[int]) -> None: + self._block_id = value + + @property + def is_full(self) -> bool: + return self.num_empty_slots == 0 + + @property + def num_empty_slots(self) -> int: + return self._block_size - len(self.token_ids) + + @property + def token_ids(self) -> List[int]: + return self._token_ids + + @property + def num_tokens_total(self) -> int: + raise NotImplementedError( + "num_tokens_total is not used for naive block") + + @property + def block_size(self) -> int: + return self._block_size + + @property + def prev_block(self) -> Optional["Block"]: + return self._prev_block + + @property + def extra_hash(self): + return None + + @property + def content_hash(self) -> Optional[int]: + return None diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py new file mode 100644 index 0000000..2913a01 --- /dev/null +++ b/vllm/core/block/prefix_caching_block.py @@ -0,0 +1,1135 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Token blocks.""" +import sys +from bisect import bisect_left +from os.path import commonprefix +from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set, + Tuple) + +from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, + get_all_blocks_recursively) +from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device, + DeviceAwareBlockAllocator) +from vllm.core.block.naive_block import (BlockPool, NaiveBlock, + NaiveBlockAllocator) +from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor +from vllm.logger import init_logger +from vllm.sequence import Sequence + +PrefixHash = int + +# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME +# so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME, +# then we know this block hasn't been accessed yet. +_DEFAULT_LAST_ACCESSED_TIME = -1 + +logger = init_logger(__name__) + + +class BlockTracker: + """Used to track the status of a block inside the prefix caching allocator + """ + __slots__ = ("active", "last_accessed", "computed") + + def reset(self): + self.last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME + self.computed: bool = False + + def __init__(self): + self.active: bool = False + self.reset() + + def enable(self): + assert not self.active + self.active = True + self.reset() + + def disable(self): + assert self.active + self.active = False + self.reset() + + +class PrefixCachingBlockAllocator(BlockAllocator): + """A block allocator that implements prefix caching. + + The PrefixCachingBlockAllocator maintains a cache of blocks based on their + content hash. It reuses blocks with the same content hash to avoid redundant + memory allocation. The allocator also supports copy-on-write operations. + + Args: + num_blocks (int): The total number of blocks to manage. + block_size (int): The size of each block in tokens. + block_ids(Optional[Iterable[int]], optional): An optional iterable of + block IDs. If not provided, block IDs will be assigned sequentially + from 0 to num_blocks - 1. + """ + + # Note that we use 'None' as a string here instead of None because + # as of Python 3.12, hash(None) returns a constant predictable value. + # This could possibly make it easier to find and exploit hash + # collisions. 'None' as a string will be hashed differently per process, + # but consistently within the same process. This is the same as the + # behavior of None prior to Python 3.12. + _none_hash: int = hash('None') + + # Implements Block.Factory. + def __init__( + self, + num_blocks: int, + block_size: int, + block_ids: Optional[Iterable[int]] = None, + eviction_policy: EvictionPolicy = EvictionPolicy.LRU, + ): + if block_ids is None: + block_ids = range(num_blocks) + + self._block_size = block_size + + # A mapping of prefix hash to block index. All blocks which have a + # prefix hash will be in this dict, even if they have refcount 0. + self._cached_blocks: Dict[PrefixHash, BlockId] = {} + + # A list of immutable block IDs that have been touched by scheduler + # and should be marked as computed after an entire batch of sequences + # are scheduled. + self._touched_blocks: Set[BlockId] = set() + + # Used to track status of each physical block id + self._block_tracker: Dict[BlockId, BlockTracker] = {} + for block_id in block_ids: + self._block_tracker[block_id] = BlockTracker() + + # Pre-allocate "num_blocks * extra_factor" block objects. + # The "* extra_factor" is a buffer to allow more block objects + # than physical blocks + extra_factor = 4 + self._block_pool = BlockPool(self._block_size, self._create_block, + self, num_blocks * extra_factor) + + # An allocator for blocks that do not have prefix hashes. + self._hashless_allocator = NaiveBlockAllocator( + create_block=self._create_block, # type: ignore + num_blocks=num_blocks, + block_size=block_size, + block_ids=block_ids, + block_pool=self._block_pool, # Share block pool here + ) + + # Evitor used to maintain how we want to handle those computed blocks + # if we find memory pressure is high. + self.eviction_policy = eviction_policy + self.evictor: Evictor = make_evictor(self.eviction_policy) + + # We share the refcounter between allocators. This allows us to promote + # blocks originally allocated in the hashless allocator to immutable + # blocks. + self._refcounter = self._hashless_allocator.refcounter + + self._cow_tracker = CopyOnWriteTracker( + refcounter=self._refcounter.as_readonly()) + + self.metric_data = CacheMetricData() + + def _create_block( + self, + prev_block: Optional[Block], + token_ids: List[int], + block_size: int, + allocator: BlockAllocator, + block_id: Optional[int] = None, + computed: bool = False, + extra_hash: Optional[int] = None, + ) -> Block: + # Bind block to self. + allocator = self + + return PrefixCachingBlock( + prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + block_id=block_id, + allocator=allocator, + computed=computed, + extra_hash=extra_hash, + ) + + def allocate_immutable_block(self, + prev_block: Optional[Block], + token_ids: List[int], + extra_hash: Optional[int] = None, + device: Optional[Device] = None) -> Block: + """Allocates an immutable block with the given token IDs, reusing cached + blocks if possible. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. + token_ids (List[int]): The token IDs to be stored in the block. + + Returns: + Block: The allocated immutable block. + """ + assert device is None + assert_prefix_caching_block_or_none(prev_block) + + # First, try to create a block that points to cached data + block = self._block_pool.init_block(prev_block=prev_block, + token_ids=token_ids, + block_size=self._block_size, + physical_block_id=None, + extra_hash=extra_hash) + assert block.content_hash is not None + + cached_block_id = self._cached_blocks.get(block.content_hash, None) + if cached_block_id is not None: + self.metric_data.query(hit=True) + block.block_id = cached_block_id + self._incr_refcount_cached_block(block) + return block + self.metric_data.query(hit=False) + self._block_pool.free_block(block) + + # No cached block => Allocate a new block + block = self.allocate_mutable_block(prev_block, extra_hash=extra_hash) + block.append_token_ids(token_ids) + return block + + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + extra_hash: Optional[int] = None, + device: Optional[Device] = None) -> List[Block]: + blocks = [] + for token_ids in block_token_ids: + prev_block = self.allocate_immutable_block(prev_block=prev_block, + token_ids=token_ids, + device=device, + extra_hash=extra_hash) + blocks.append(prev_block) + return blocks + + def allocate_mutable_block(self, + prev_block: Optional[Block], + extra_hash: Optional[int] = None, + device: Optional[Device] = None) -> Block: + """Allocates a mutable block. If there are no free blocks, this will + evict unused cached blocks. + + Args: + prev_block (Block): The previous block in the sequence. + None is not allowed unlike it is super class. + + Returns: + Block: The allocated mutable block. + """ + assert device is None + assert_prefix_caching_block_or_none(prev_block) + + block_id = self._allocate_block_id() + block = self._block_pool.init_block(prev_block=prev_block, + token_ids=[], + block_size=self._block_size, + physical_block_id=block_id, + extra_hash=extra_hash) + assert not block.computed + assert block.content_hash is None + return block + + def _incr_refcount_cached_block(self, block: Block) -> None: + # Set this block to be "computed" since it is pointing to a + # cached block id (which was already computed) + block.computed = True + + block_id = block.block_id + assert block_id is not None + + refcount = self._refcounter.incr(block_id) + if refcount == 1: + # In case a cached block was evicted, restore its tracking + if block_id in self.evictor: + self.evictor.remove(block_id) + + self._track_block_id(block_id, computed=True) + + def _decr_refcount_cached_block(self, block: Block) -> None: + # Ensure this is immutable/cached block + assert block.content_hash is not None + + block_id = block.block_id + assert block_id is not None + + refcount = self._refcounter.decr(block_id) + if refcount > 0: + block.block_id = None + return + else: + assert refcount == 0 + + # No longer used + assert block.content_hash in self._cached_blocks + + # Add the cached block to the evictor + # (This keeps the cached block around so it can be reused) + self.evictor.add(block_id, block.content_hash, block.num_tokens_total, + self._block_tracker[block_id].last_accessed) + + # Stop tracking the block + self._untrack_block_id(block_id) + + block.block_id = None + + def _decr_refcount_hashless_block(self, block: Block) -> None: + block_id = block.block_id + assert block_id is not None + + # We may have a fork case where block is shared, + # in which case, we cannot remove it from tracking + refcount = self._refcounter.get(block_id) + if refcount == 1: + self._untrack_block_id(block_id) + + # Decrement refcount of the block_id, but do not free the block object + # itself (will be handled by the caller) + self._hashless_allocator.free(block, keep_block_object=True) + + def _allocate_block_id(self) -> BlockId: + """First tries to allocate a block id from the hashless allocator, + and if there are no blocks, then tries to evict an unused cached block. + """ + hashless_block_id = self._maybe_allocate_hashless_block_id() + if hashless_block_id is not None: + return hashless_block_id + + evicted_block_id = self._maybe_allocate_evicted_block_id() + if evicted_block_id is not None: + return evicted_block_id + + # No block available in hashless allocator, nor in unused cache blocks. + raise BlockAllocator.NoFreeBlocksError() + + def _maybe_allocate_hashless_block_id(self) -> Optional[BlockId]: + try: + # Allocate mutable block and extract its block_id + block = self._hashless_allocator.allocate_mutable_block( + prev_block=None) + block_id = block.block_id + self._block_pool.free_block(block) + + self._track_block_id(block_id, computed=False) + return block_id + except BlockAllocator.NoFreeBlocksError: + return None + + def _maybe_allocate_evicted_block_id(self) -> Optional[BlockId]: + if self.evictor.num_blocks == 0: + return None + + # Here we get an evicted block, which is only added + # into evictor if its ref counter is 0 + # and since its content would be changed, we need + # to remove it from _cached_blocks's tracking list + block_id, content_hash_to_evict = self.evictor.evict() + + # Sanity checks + assert content_hash_to_evict in self._cached_blocks + _block_id = self._cached_blocks[content_hash_to_evict] + assert self._refcounter.get(_block_id) == 0 + assert _block_id == block_id + + self._cached_blocks.pop(content_hash_to_evict) + + self._refcounter.incr(block_id) + self._track_block_id(block_id, computed=False) + + return block_id + + def _free_block_id(self, block: Block) -> None: + """Decrements the refcount of the block. The block may be in two + possible states: (1) immutable/cached or (2) mutable/hashless. + In the first case, the refcount is decremented directly and the block + may be possibly added to the evictor. In other case, hashless + allocator free(..) with keep_block_object=True is called to only free + the block id (since the block object may be reused by the caller) + """ + block_id = block.block_id + assert block_id is not None, "Freeing unallocated block is undefined" + + if block.content_hash is not None: + # Immutable: This type of block is always cached, and we want to + # keep it in the evictor for future reuse + self._decr_refcount_cached_block(block) + else: + # Mutable: This type of block is not cached, so we release it + # directly to the hashless allocator + self._decr_refcount_hashless_block(block) + + assert block.block_id is None + + def free(self, block: Block, keep_block_object: bool = False) -> None: + """Release the block (look at free_block_id(..) docs) + """ + # Release the physical block index + self._free_block_id(block) + + # Release the block object to the pool + if not keep_block_object: + self._block_pool.free_block(block) + + def fork(self, last_block: Block) -> List[Block]: + """Creates a new sequence of blocks that shares the same underlying + memory as the original sequence. + + Args: + last_block (Block): The last block in the original sequence. + + Returns: + List[Block]: The new sequence of blocks that shares the same memory + as the original sequence. + """ + source_blocks = get_all_blocks_recursively(last_block) + + forked_blocks: List[Block] = [] + prev_block = None + for block in source_blocks: + block_id = block.block_id + assert block_id is not None + + refcount = self._refcounter.incr(block_id) + assert refcount != 1, "can't fork free'd block_id = {}".format( + block_id) + + forked_block = self._block_pool.init_block( + prev_block=prev_block, + token_ids=block.token_ids, + block_size=self._block_size, + physical_block_id=block_id, + extra_hash=block.extra_hash) + + forked_blocks.append(forked_block) + prev_block = forked_blocks[-1] + + return forked_blocks + + def get_num_free_blocks(self, device: Optional[Device] = None) -> int: + assert device is None + # The number of free blocks is the number of hashless free blocks + # plus the number of blocks evictor could free from its list. + return self._hashless_allocator.get_num_free_blocks( + ) + self.evictor.num_blocks + + def get_num_total_blocks(self) -> int: + return self._hashless_allocator.get_num_total_blocks() + + def get_physical_block_id(self, absolute_id: int) -> int: + """Returns the zero-offset block id on certain block allocator + given the absolute block id. + + Args: + absolute_id (int): The absolute block id for the block + in whole allocator. + + Returns: + int: The rzero-offset block id on certain device. + """ + return sorted(self.all_block_ids).index(absolute_id) + + @property + def all_block_ids(self) -> FrozenSet[int]: + return self._hashless_allocator.all_block_ids + + def get_prefix_cache_hit_rate(self) -> float: + return self.metric_data.get_hit_rate() + + def reset_prefix_cache(self) -> bool: + """Reset prefix cache. This function may be used in RLHF + flows to invalid prefix caching after the weights are updated, + or used for resetting prefix caching status for benchmarking. + + Returns: + bool: True if the prefix cache is successfully reset, + False otherwise. + """ + num_used_blocks = (self.get_num_total_blocks() - + self.get_num_free_blocks()) + if num_used_blocks > 0: + logger.warning( + "Failed to reset prefix cache because some " + "blocks (%d) are not freed yet", num_used_blocks) + return False + + # Free all blocks in the evictor. + while (block_id := + self._maybe_allocate_evicted_block_id()) is not None: + self._hashless_allocator.free_block_id(block_id) + + # Should not have any cached blocks because all blocks are evicted. + assert not self._cached_blocks + + # Reset the evictor. + self.evictor = make_evictor(self.eviction_policy) + + # Reset the block tracker. + for block_id in self._block_tracker: + self._block_tracker[block_id] = BlockTracker() + + # Reset the metrics. + self.metric_data = CacheMetricData() + + logger.info("Successfully reset prefix cache") + return True + + def is_block_cached(self, block: Block) -> bool: + assert block.content_hash is not None + return block.content_hash in self._cached_blocks + + def promote_to_immutable_block(self, block: Block) -> BlockId: + """Once a mutable block is full, it can be promoted to an immutable + block. This means that its content can be referenced by future blocks + having the same prefix. + + Note that if we already have a cached block with the same content, we + will replace the newly-promoted block's mapping with the existing cached + block id. + + Args: + block: The mutable block to be promoted. + + Returns: + BlockId: Either the original block index, or the block index of + the previously cached block matching the same content. + """ + # Ensure block can be promoted + assert block.content_hash is not None + assert block.block_id is not None + assert self._refcounter.get(block.block_id) > 0 + + if block.content_hash not in self._cached_blocks: + # No cached content hash => Set this block as cached. + # Note that this block cannot be marked as computed yet + # because other sequences in the same batch cannot reuse + # this block. + self._cached_blocks[block.content_hash] = block.block_id + # Mark this block as touched so that it can be marked as + # computed after the entire batch of sequences are scheduled. + self._touched_blocks.add(block.block_id) + return block.block_id + + # Reuse the cached content hash + self._decr_refcount_hashless_block(block) + block.block_id = self._cached_blocks[block.content_hash] + + # Increment refcount of the cached block and (possibly) restore + # it from the evictor. + # Note that in this case, the block is marked as computed + self._incr_refcount_cached_block(block) + + return block.block_id + + def cow_block_if_not_appendable(self, block: Block) -> BlockId: + """Performs a copy-on-write operation on the given block if it is not + appendable. + + Args: + block (Block): The block to check for copy-on-write. + + Returns: + BlockId: The block index of the new block if a copy-on-write + operation was performed, or the original block index if + no copy-on-write was necessary. + """ + src_block_id = block.block_id + assert src_block_id is not None + + if self._cow_tracker.is_appendable(block): + return src_block_id + + self._free_block_id(block) + trg_block_id = self._allocate_block_id() + + self._cow_tracker.record_cow(src_block_id, trg_block_id) + + return trg_block_id + + def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: + """Returns the copy-on-write source->destination mapping and clears it. + + Returns: + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices. + """ + return self._cow_tracker.clear_cows() + + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, used in prefix caching. + + If the block is added into evictor, we need to update corresponding + info in evictor's metadata. + """ + + for block_id in block_ids: + if self._block_tracker[block_id].active: + self._block_tracker[block_id].last_accessed = now + elif block_id in self.evictor: + self.evictor.update(block_id, now) + else: + raise ValueError( + "Mark block as accessed which is not belonged to GPU") + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + # Mark all touched blocks as computed. + for block_id in self._touched_blocks: + self._block_tracker[block_id].computed = True + self._touched_blocks.clear() + + def _track_block_id(self, block_id: Optional[BlockId], + computed: bool) -> None: + assert block_id is not None + self._block_tracker[block_id].enable() + self._block_tracker[block_id].computed = computed + + def _untrack_block_id(self, block_id: Optional[BlockId]) -> None: + assert block_id is not None + self._block_tracker[block_id].disable() + + def block_is_computed(self, block_id: int) -> bool: + if self._block_tracker[block_id].active: + return self._block_tracker[block_id].computed + else: + return block_id in self.evictor + + def get_common_computed_block_ids( + self, computed_seq_block_ids: List[List[int]]) -> List[int]: + """Return the block ids that are common for a given sequence group. + + Only those blocks that are immutable and already be marked + compyted would be taken consideration. + """ + + # NOTE We exclude the last block to avoid the case where the entire + # prompt is cached. This would cause erroneous behavior in model + # runner. + + # It returns a list of int although type annotation says list of string. + if len(computed_seq_block_ids) == 1: + return computed_seq_block_ids[0] + + return commonprefix([ + ids for ids in computed_seq_block_ids # type: ignore + if ids + ]) + + def get_num_full_blocks_touched(self, blocks: List[Block]) -> int: + """Returns the number of full blocks that will be touched by + swapping in/out. + + Args: + blocks: List of blocks to be swapped. + Returns: + int: the number of full blocks that will be touched by + swapping in/out the given blocks. Non full blocks are ignored + when deciding the number of blocks to touch. + """ + num_touched_blocks: int = 0 + for block in blocks: + # If the block has a match in the cache and the cached + # block is not referenced, then we still count it as a + # touched block + if block.is_full and (not self.is_block_cached(block) or \ + (block.content_hash is not None and \ + self._cached_blocks[block.content_hash] in \ + self.evictor)): + num_touched_blocks += 1 + return num_touched_blocks + + def swap_out(self, blocks: List[Block]) -> None: + """Execute the swap out actions. Basically just free the + given blocks. + + Args: + blocks: List of blocks to be swapped out. + """ + for block in blocks: + self._free_block_id(block) + + def swap_in(self, blocks: List[Block]) -> None: + """Execute the swap in actions. Change the block id from + old allocator to current allocator for each block to finish + the block table update. + + Args: + blocks: List of blocks to be swapped in. + """ + for block in blocks: + # Here we allocate either immutable or mutable block and then + # extract its block_id. Note that the block object is released + # and the block_id is assigned to "block" to allow reusing the + # existing "block" object + if block.is_full: + tmp_block = self.allocate_immutable_block( + prev_block=block.prev_block, + token_ids=block.token_ids, + extra_hash=block.extra_hash) + else: + tmp_block = self.allocate_mutable_block( + prev_block=block.prev_block, extra_hash=block.extra_hash) + tmp_block.append_token_ids(block.token_ids) + + block_id = tmp_block.block_id + self._block_pool.free_block(tmp_block) + + block.block_id = block_id # Assign block_id + + def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]: + """ + Given a list of block hashes, return the prefix of the block hashes that + are all cached. + + Since a block's block hash includes the hashes of all previous blocks, + and we only allocate/deallocate blocks in the entire sequence, so if a + block is cached, then all previous blocks are also cached. With this + property, we can use binary search to find the prefix of cached blocks. + + Args: + block_hashes (List[int]): The list of block hashes. + + Returns: + List[int]: The prefix of the `block_hashes` that are cached. + """ + + def _block_is_cached(block_hash: PrefixHash) -> bool: + if block_hash not in self._cached_blocks: + return False + + cached_block_id = self._cached_blocks[block_hash] + # We only consider the blocks that are marked as computed. + return self.block_is_computed(cached_block_id) + + def _bisect_left(a, x, key: Callable[[PrefixHash], bool]) -> int: + + # python <= 3.10 don't have the key argument + if sys.version_info < (3, 10): + a = [key(e) for e in a] + return bisect_left(a, x) + else: + return bisect_left(a, x, key=key) + + # Look for the first block that's not cached, and returns the prefix + # i.e. blocks that are cached. + idx = _bisect_left(block_hashes, + True, + key=lambda x: not _block_is_cached(x)) + return block_hashes[:idx] + + +class PrefixCachingBlock(Block): + """A block implementation that supports prefix caching. + + The PrefixCachingBlock class represents a block of token IDs with prefix + caching capabilities. It wraps a NaiveBlock internally and provides + additional functionality for content hashing and promoting immutable blocks + with the prefix caching allocator. + + Args: + prev_block (Optional[PrefixCachingBlock]): The previous block in the + sequence. + token_ids (List[int]): The initial token IDs to be stored in the block. + block_size (int): The maximum number of token IDs that can be stored in + the block. + allocator (BlockAllocator): The prefix + caching block allocator associated with this block. + block_id (Optional[int], optional): The physical block index + of this block. Defaults to None. + extra_hash (Optional[int]): The hash value of additional factors + such as adapters that influence the block, apart from the token_ids. + """ + + # Note that we use 'None' as a string here instead of None because + # as of Python 3.12, hash(None) returns a constant predictable value. + # This could possibly make it easier to find and exploit hash + # collisions. 'None' as a string will be hashed differently per process, + # but consistently within the same process. This is the same as the + # behavior of None prior to Python 3.12. + _none_hash: int = hash('None') + + def __init__( + self, + prev_block: Optional[Block], + token_ids: List[int], + block_size: int, + allocator: BlockAllocator, + block_id: Optional[int] = None, + computed: bool = False, + extra_hash: Optional[int] = None, + ): + assert isinstance(allocator, PrefixCachingBlockAllocator), ( + "Currently this class is only tested with " + "PrefixCachingBlockAllocator. Got instead allocator = {}".format( + allocator)) + assert_prefix_caching_block_or_none(prev_block) + + self._prev_block = prev_block + self._cached_content_hash: Optional[int] = None + self._cached_num_tokens_total: int = 0 + self._allocator = allocator + self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME + self._computed = computed + self._extra_hash = extra_hash + + # On the first time, we create the block object, and next we only + # reinitialize it + if hasattr(self, "_block"): + self._block.__init__( # type: ignore[has-type] + prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + block_id=block_id, + allocator=self._allocator) + else: + self._block = NaiveBlock(prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + block_id=block_id, + allocator=self._allocator) + + self._update_num_tokens_total() + + def _update_num_tokens_total(self): + """Incrementally computes the number of tokens that there is + till the current block (included) + """ + res = 0 + + # Add all previous blocks + if self._prev_block is not None: + res += self._prev_block.num_tokens_total + + # Add current block + res += len(self.token_ids) + + self._cached_num_tokens_total = res + + @property + def computed(self) -> bool: + return self._computed + + @computed.setter + def computed(self, value) -> None: + self._computed = value + + @property + def last_accessed(self) -> float: + return self._last_accessed + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + self._last_accessed = last_accessed_ts + + def append_token_ids(self, token_ids: List[int]) -> None: + """Appends the given token IDs to the block and registers the block as + immutable if the block becomes full. + + Args: + token_ids (List[int]): The token IDs to be appended to the block. + """ + # Ensure this is mutable block (not promoted) + assert self.content_hash is None + assert not self.computed + + if len(token_ids) == 0: + return + + # Ensure there are input tokens + assert token_ids, "Got token_ids = {}".format(token_ids) + + # Naive block handles CoW. + self._block.append_token_ids(token_ids) + self._update_num_tokens_total() + + # If the content hash is present, then the block can be made immutable. + # Register ourselves with the allocator, potentially replacing the + # physical block index. + if self.content_hash is not None: + self.block_id = self._allocator.promote_to_immutable_block(self) + + @property + def block_id(self) -> Optional[int]: + return self._block.block_id + + @block_id.setter + def block_id(self, value) -> None: + self._block.block_id = value + + @property + def is_full(self) -> bool: + return self._block.is_full + + @property + def num_empty_slots(self) -> int: + return self._block.num_empty_slots + + @property + def num_tokens_total(self) -> int: + return self._cached_num_tokens_total + + @property + def block_size(self) -> int: + return self._block.block_size + + @property + def token_ids(self) -> List[int]: + return self._block.token_ids + + @property + def prev_block(self) -> Optional[Block]: + return self._prev_block + + @property + def extra_hash(self) -> Optional[int]: + return self._extra_hash + + @property + def content_hash(self) -> Optional[int]: + """Return the content-based hash of the current block, or None if it is + not yet defined. + + For the content-based hash to be defined, the current block must be + full. + """ + # If the hash is already computed, return it. + if self._cached_content_hash is not None: + return self._cached_content_hash + + # We cannot compute a hash for the current block because it is not full. + if not self.is_full: + return None + + is_first_block = self._prev_block is None + prev_block_hash = ( + self._none_hash if is_first_block else + self._prev_block.content_hash # type: ignore + ) + + # Previous block exists but does not yet have a hash. + # Return no hash in this case. + if prev_block_hash == self._none_hash and not is_first_block: + return None + + self._cached_content_hash = PrefixCachingBlock.hash_block_tokens( + is_first_block, + prev_block_hash, + cur_block_token_ids=self.token_ids, + extra_hash=self._extra_hash) + return self._cached_content_hash + + @classmethod + def hash_block_tokens(cls, + is_first_block: bool, + prev_block_hash: Optional[int], + cur_block_token_ids: List[int], + extra_hash: Optional[int] = None) -> int: + """Computes a hash value corresponding to the contents of a block and + the contents of the preceding block(s). The hash value is used for + prefix caching. + + Parameters: + - is_first_block (bool): A flag indicating if the block is the first in + the sequence. + - prev_block_hash (Optional[int]): The hash of the previous block. None + if this is the first block. + - cur_block_token_ids (List[int]): A list of token ids in the current + block. The current block is assumed to be full. + - extra_hash (Optional[int]): The hash value of additional factors + such as adapters that influence the block, apart from the token_ids. + + Returns: + - int: The computed hash value for the block. + """ + if is_first_block and prev_block_hash is None: + prev_block_hash = cls._none_hash + return hash((is_first_block, prev_block_hash, *cur_block_token_ids, + extra_hash)) + + +class ComputedBlocksTracker: + """ + Tracks the computed blocks for each sequence. + + Internally, it maintains a map from sequence id to the list of block hashes + for the sequence. We cache the hashes of the full blocks for each sequence, + and make sure the hash is calculated in the same way as the allocator. + When a sequence is being decoded, we also update the sequence's hash + accordingly and incrementally. + + From the sequence hash, with prefix caching enabled, we could also calculate + the number of cached tokens for the sequence by looking up the number of + cached block hashes in the allocator. + """ + + # Note that we use 'None' as a string here instead of None because + # as of Python 3.12, hash(None) returns a constant predictable value. + # This could possibly make it easier to find and exploit hash + # collisions. 'None' as a string will be hashed differently per process, + # but consistently within the same process. This is the same as the + # behavior of None prior to Python 3.12. + _none_hash: int = hash('None') + + def __init__( + self, + allocator: DeviceAwareBlockAllocator, + block_size: int, + enable_caching: bool, + ): + self._allocator = allocator + self._block_size = block_size + self._enable_caching = enable_caching + + # A map from seq_id to the list of block hashes for the + # sequence. This is so that we don't have to recompute the block hashes + # for the sequence when we need to check if the sequence is cached. + # Note a block that's not full will not have its hash calculated and + # recorded. + self._seq_id_to_blocks_hashes: Dict[int, List[int]] = {} + + # A map from seq_id to the number of tokens that are cached for the + # sequence. + # We need this so that a sequence in continuous prefill doesn't + # accidentally see its cached token count change. See comments in + # `get_num_cached_tokens` for more details. + self._seq_id_to_num_tokens_computed: Dict[int, int] = {} + + def _update_seq_hashes(self, seq: Sequence) -> None: + """Incrementally update the sequence's block hashes and record them.""" + assert self._enable_caching + + block_hashes_recorded = self._seq_id_to_blocks_hashes.get( + seq.seq_id, []) + cur_num_blocks_recorded = len(block_hashes_recorded) + token_ids = seq.get_token_ids() + assert len(token_ids) >= cur_num_blocks_recorded * self._block_size, ( + f"The sequence has {len(token_ids)} tokens, but" + f" already recorded {cur_num_blocks_recorded} blocks. " + "This should not happen since we assume blocks are " + "only appended other than recomputation. When the sequence is " + "recomputed, we should have removed the info of the old blocks.") + # Update the computed block hashes for the sequence. Since only full + # blocks are considered as "computed", we take floor here. + num_computed_blocks = len(token_ids) // self._block_size + + # We need to know the hash of the previous block to compute the hash of + # the current block so that blocks could be uniquely identified across + # sequences of prefixes. + prev_block_hash = (self._none_hash if cur_num_blocks_recorded == 0 else + block_hashes_recorded[-1]) + # Only update the computed block hashes for the new blocks + for i in range(cur_num_blocks_recorded, num_computed_blocks): + assert len(token_ids) >= (i + 1) * self._block_size + block_token_ids = token_ids[i * self._block_size:(i + 1) * + self._block_size] + + # NOTE: If there are any factors affecting the block besides + # token_ids, they should be added as input to extra_hash. + extra_hash = seq.extra_hash() + + # This has to be kept in sync with the allocator's hash + # calculation. + block_hash = PrefixCachingBlock.hash_block_tokens( + is_first_block=prev_block_hash == self._none_hash, + prev_block_hash=prev_block_hash, + cur_block_token_ids=block_token_ids, + extra_hash=extra_hash, + ) + block_hashes_recorded.append(block_hash) + prev_block_hash = block_hash + + self._seq_id_to_blocks_hashes[seq.seq_id] = block_hashes_recorded + + def get_num_cached_tokens(self, seq: Sequence) -> int: + if not self._enable_caching: + return 0 + + # We always try to update the sequence hashes on the fly. + # This is to ensure that we don't miss any cached tokens for the + # sequence during decode. + # This routine should only update hash for any new blocks too. + self._update_seq_hashes(seq) + + num_computed_tokens_prev = self._seq_id_to_num_tokens_computed.get( + seq.seq_id, None) + + # TODO(rickyx): This hack could be removed once we mark blocks as + # computed correctly with chunked prefills. + if num_computed_tokens_prev is not None and seq.is_prefill(): + # For a sequence that is still in prefill, we don't + # recompute the number of cached tokens. + # This also handles correctly chunked prefill since currently + # we mark blocks as computed even if the sequence is still partially + # prefilled. So a continuously prefilled sequence should not + # see its cached token count change while running. + return num_computed_tokens_prev + + block_hashes = self._seq_id_to_blocks_hashes[seq.seq_id] + + # This is O(logN), where N is the number of blocks. + num_cached_blocks = len( + self._allocator.find_cached_blocks_prefix(block_hashes)) + num_cached_tokens = num_cached_blocks * self._block_size + self._seq_id_to_num_tokens_computed[seq.seq_id] = num_cached_tokens + return num_cached_tokens + + def remove_seq(self, seq_id: int) -> None: + """Stop tracking the sequence.""" + if not self._enable_caching: + return + assert seq_id in self._seq_id_to_blocks_hashes + del self._seq_id_to_blocks_hashes[seq_id] + + assert seq_id in self._seq_id_to_num_tokens_computed + del self._seq_id_to_num_tokens_computed[seq_id] + + +class LastAccessBlocksTracker: + """Manages the last access time of the tracked sequences, in order to allow + an efficient update of allocator's block last access times + """ + + def __init__(self, allocator): + self._allocator = allocator + self._seq_last_access: Dict[int, Optional[float]] = {} + + def add_seq(self, seq_id: int) -> None: + """Start tracking seq_id + """ + assert seq_id not in self._seq_last_access + self._seq_last_access[seq_id] = None + + def remove_seq(self, seq_id: int) -> None: + """Stop tracking seq_id + """ + assert seq_id in self._seq_last_access + del self._seq_last_access[seq_id] + + def update_last_access(self, seq_id: int, time: float) -> None: + assert seq_id in self._seq_last_access + self._seq_last_access[seq_id] = time + + def update_seq_blocks_last_access(self, seq_id: int, + block_ids: List[int]) -> None: + assert seq_id in self._seq_last_access + + ts = self._seq_last_access[seq_id] + + if ts is None: + # No last access was recorded, no need to update. + return + + self._allocator.mark_blocks_as_accessed(block_ids, ts) + + +def assert_prefix_caching_block_or_none(block: Optional[Block]): + if block is None: + return + assert isinstance(block, + PrefixCachingBlock), "Got block = {}".format(block) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py new file mode 100644 index 0000000..e933c6e --- /dev/null +++ b/vllm/core/block/utils.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Block manager utils.""" +from vllm.sequence import SequenceGroup +from vllm.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) + + +def check_no_caching_or_swa_for_blockmgr_encdec( + block_mgr, seq_group: SequenceGroup) -> None: + ''' + Enforce that prefix caching & sliding-window attention (SWA) + are currently unsupported *specifically* for encoder/decoder models. + + Raises NotImplementedError if unsupported scenario is detected. + + Arguments: + + * block_mgr: BlockSpaceManager instance + * seq_group: SequenceGroup passed to block_mgr + ''' + + if seq_group.is_encoder_decoder(): + if block_mgr.max_block_sliding_window is not None: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) + + if block_mgr.enable_caching: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py new file mode 100644 index 0000000..4ec5a77 --- /dev/null +++ b/vllm/core/block_manager.py @@ -0,0 +1,525 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""A block manager that manages token blocks.""" +from typing import Dict, List, Optional +from typing import Sequence as GenericSequence +from typing import Tuple + +from vllm.core.block.block_table import BlockTable +from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.core.block.interfaces import Block +from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, + LastAccessBlocksTracker) +from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.sequence import Sequence, SequenceGroup, SequenceStatus +from vllm.utils import Device + +SeqId = int +EncoderSeqId = str + + +class SelfAttnBlockSpaceManager(BlockSpaceManager): + """BlockSpaceManager which manages the allocation of KV cache. + + It owns responsibility for allocation, swapping, allocating memory for + autoregressively-generated tokens, and other advanced features such as + prefix caching, forking/copy-on-write, and sliding-window memory allocation. + + This class implements the design described in + https://github.com/vllm-project/vllm/pull/3492. + + Lookahead slots + The block manager has the notion of a "lookahead slot". These are slots + in the KV cache that are allocated for a sequence. Unlike the other + allocated slots, the content of these slots is undefined -- the worker + may use the memory allocations in any way. + + In practice, a worker could use these lookahead slots to run multiple + forward passes for a single scheduler invocation. Each successive + forward pass would write KV activations to the corresponding lookahead + slot. This allows low inter-token latency use-cases, where the overhead + of continuous batching scheduling is amortized over >1 generated tokens. + + Speculative decoding uses lookahead slots to store KV activations of + proposal tokens. + + See https://github.com/vllm-project/vllm/pull/3250 for more information + on lookahead scheduling. + + Args: + block_size (int): The size of each memory block. + num_gpu_blocks (int): The number of memory blocks allocated on GPU. + num_cpu_blocks (int): The number of memory blocks allocated on CPU. + watermark (float, optional): The threshold used for memory swapping. + Defaults to 0.01. + sliding_window (Optional[int], optional): The size of the sliding + window. Defaults to None. + enable_caching (bool, optional): Flag indicating whether caching is + enabled. Defaults to False. + """ + + def __init__( + self, + block_size: int, + num_gpu_blocks: int, + num_cpu_blocks: int, + watermark: float = 0.01, + sliding_window: Optional[int] = None, + enable_caching: bool = False, + ) -> None: + self.block_size = block_size + self.num_total_gpu_blocks = num_gpu_blocks + self.num_total_cpu_blocks = num_cpu_blocks + + self.sliding_window = sliding_window + # max_block_sliding_window is the max number of blocks that need to be + # allocated + self.max_block_sliding_window = None + if sliding_window is not None: + # +1 here because // rounds down + num_blocks = sliding_window // block_size + 1 + # +1 here because the last block may not be full, + # and so the sequence stretches one more block at the beginning + # For example, if sliding_window is 3 and block_size is 4, + # we may need 2 blocks when the second block only holds 1 token. + self.max_block_sliding_window = num_blocks + 1 + + self.watermark = watermark + assert watermark >= 0.0 + + self.enable_caching = enable_caching + + self.watermark_blocks = int(watermark * num_gpu_blocks) + + self.block_allocator = CpuGpuBlockAllocator.create( + allocator_type="prefix_caching" if enable_caching else "naive", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + block_size=block_size, + ) + + self.block_tables: Dict[SeqId, BlockTable] = {} + self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} + + self._computed_blocks_tracker = ComputedBlocksTracker( + self.block_allocator, self.block_size, self.enable_caching) + self._last_access_blocks_tracker = LastAccessBlocksTracker( + self.block_allocator) + + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: + # FIXME(woosuk): Here we assume that all sequences in the group share + # the same prompt. This may not be true for preempted sequences. + + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + num_required_blocks = BlockTable.get_num_required_blocks( + seq.get_token_ids(), + block_size=self.block_size, + num_lookahead_slots=num_lookahead_slots, + ) + + if seq_group.is_encoder_decoder(): + encoder_seq = seq_group.get_encoder_seq() + assert encoder_seq is not None + num_required_blocks += BlockTable.get_num_required_blocks( + encoder_seq.get_token_ids(), + block_size=self.block_size, + ) + + if self.max_block_sliding_window is not None: + num_required_blocks = min(num_required_blocks, + self.max_block_sliding_window) + + num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( + device=Device.GPU) + + # Use watermark to avoid frequent cache eviction. + if (self.num_total_gpu_blocks - num_required_blocks + < self.watermark_blocks): + return AllocStatus.NEVER + if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: + return AllocStatus.OK + else: + return AllocStatus.LATER + + def _allocate_sequence(self, seq: Sequence) -> BlockTable: + block_table = BlockTable( + block_size=self.block_size, + block_allocator=self.block_allocator, + max_block_sliding_window=self.max_block_sliding_window, + ) + if seq.get_token_ids(): + # NOTE: If there are any factors affecting the block besides + # token_ids, they should be added as input to extra_hash. + extra_hash = seq.extra_hash() + + # Add blocks to the block table only if the sequence is non empty. + block_table.allocate(token_ids=seq.get_token_ids(), + extra_hash=extra_hash) + + return block_table + + def allocate(self, seq_group: SequenceGroup) -> None: + + # Allocate self-attention block tables for decoder sequences + waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) + assert not (set(seq.seq_id for seq in waiting_seqs) + & self.block_tables.keys()), "block table already exists" + + # NOTE: Here we assume that all sequences in the group have the same + # prompt. + seq = waiting_seqs[0] + block_table: BlockTable = self._allocate_sequence(seq) + self.block_tables[seq.seq_id] = block_table + + # Track seq + self._last_access_blocks_tracker.add_seq(seq.seq_id) + + # Assign the block table for each sequence. + for seq in waiting_seqs[1:]: + self.block_tables[seq.seq_id] = block_table.fork() + + # Track seq + self._last_access_blocks_tracker.add_seq(seq.seq_id) + + # Allocate cross-attention block table for encoder sequence + # + # NOTE: Here we assume that all sequences in the group have the same + # encoder prompt. + request_id = seq_group.request_id + + assert (request_id + not in self.cross_block_tables), \ + "block table already exists" + + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + if seq_group.is_encoder_decoder(): + encoder_seq = seq_group.get_encoder_seq() + assert encoder_seq is not None + block_table = self._allocate_sequence(encoder_seq) + self.cross_block_tables[request_id] = block_table + + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: + """Determine if there is enough space in the GPU KV cache to continue + generation of the specified sequence group. + + We use a worst-case heuristic: assume each touched block will require a + new allocation (either via CoW or new block). We can append slots if the + number of touched blocks is less than the number of free blocks. + + "Lookahead slots" are slots that are allocated in addition to the slots + for known tokens. The contents of the lookahead slots are not defined. + This is used by speculative decoding when speculating future tokens. + """ + + num_touched_blocks = 0 + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + block_table = self.block_tables[seq.seq_id] + + num_touched_blocks += ( + block_table.get_num_blocks_touched_by_append_slots( + token_ids=block_table.get_unseen_token_ids( + seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots, + )) + + num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( + Device.GPU) + return num_touched_blocks <= num_free_gpu_blocks + + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int, + ) -> List[Tuple[int, int]]: + + block_table = self.block_tables[seq.seq_id] + + block_table.append_token_ids( + token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots, + num_computed_slots=seq.data.get_num_computed_tokens(), + extra_hash=seq.extra_hash(), + ) + # Return any new copy-on-writes. + new_cows = self.block_allocator.clear_copy_on_writes() + return new_cows + + def free(self, seq: Sequence) -> None: + seq_id = seq.seq_id + + if seq_id not in self.block_tables: + # Already freed or haven't been scheduled yet. + return + + # Update seq block ids with the latest access time + self._last_access_blocks_tracker.update_seq_blocks_last_access( + seq_id, self.block_tables[seq.seq_id].physical_block_ids) + + # Untrack seq + self._last_access_blocks_tracker.remove_seq(seq_id) + self._computed_blocks_tracker.remove_seq(seq_id) + + # Free table/blocks + self.block_tables[seq_id].free() + del self.block_tables[seq_id] + + def remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None: + seq_id = seq.seq_id + self._computed_blocks_tracker.remove_seq(seq_id) + + def free_cross(self, seq_group: SequenceGroup) -> None: + request_id = seq_group.request_id + if request_id not in self.cross_block_tables: + # Already freed or hasn't been scheduled yet. + return + self.cross_block_tables[request_id].free() + del self.cross_block_tables[request_id] + + def get_block_table(self, seq: Sequence) -> List[int]: + block_ids = self.block_tables[seq.seq_id].physical_block_ids + return block_ids # type: ignore + + def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: + request_id = seq_group.request_id + assert request_id in self.cross_block_tables + block_ids = self.cross_block_tables[request_id].physical_block_ids + assert all(b is not None for b in block_ids) + return block_ids # type: ignore + + def access_all_blocks_in_seq(self, seq: Sequence, now: float): + if self.enable_caching: + # Record the latest access time for the sequence. The actual update + # of the block ids is deferred to the sequence free(..) call, since + # only during freeing of block ids, the blocks are actually added to + # the evictor (which is when the most updated time is required) + # (This avoids expensive calls to mark_blocks_as_accessed(..)) + self._last_access_blocks_tracker.update_last_access( + seq.seq_id, now) + + def mark_blocks_as_computed(self, seq_group: SequenceGroup, + token_chunk_size: int): + # If prefix caching is enabled, mark immutable blocks as computed + # right after they have been scheduled (for prefill). This assumes + # the scheduler is synchronous so blocks are actually computed when + # scheduling the next batch. + self.block_allocator.mark_blocks_as_computed([]) + + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: + """Determine which blocks for which we skip prefill. + + With prefix caching we can skip prefill for previously-generated blocks. + Currently, the attention implementation only supports skipping cached + blocks if they are a contiguous prefix of cached blocks. + + This method determines which blocks can be safely skipped for all + sequences in the sequence group. + """ + computed_seq_block_ids = [] + for seq in seqs: + all_blocks = self.block_tables[seq.seq_id].physical_block_ids + num_cached_tokens = ( + self._computed_blocks_tracker.get_num_cached_tokens(seq)) + assert num_cached_tokens % self.block_size == 0 + num_cached_blocks = num_cached_tokens // self.block_size + computed_block_ids = all_blocks[:num_cached_blocks] + computed_seq_block_ids.append(computed_block_ids) + + # NOTE(sang): This assumes seq_block_ids doesn't contain any None. + return self.block_allocator.get_common_computed_block_ids( + computed_seq_block_ids) # type: ignore + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + if parent_seq.seq_id not in self.block_tables: + # Parent sequence has either been freed or never existed. + return + src_block_table = self.block_tables[parent_seq.seq_id] + self.block_tables[child_seq.seq_id] = src_block_table.fork() + + # Track child seq + self._last_access_blocks_tracker.add_seq(child_seq.seq_id) + + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> AllocStatus: + """Returns the AllocStatus for the given sequence_group + with num_lookahead_slots. + + Args: + sequence_group (SequenceGroup): The sequence group to swap in. + num_lookahead_slots (int): Number of lookahead slots used in + speculative decoding, default to 0. + + Returns: + AllocStatus: The AllocStatus for the given sequence group. + """ + return self._can_swap(seq_group, Device.GPU, SequenceStatus.SWAPPED, + num_lookahead_slots) + + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + """Returns the block id mapping (from CPU to GPU) generated by + swapping in the given seq_group with num_lookahead_slots. + + Args: + seq_group (SequenceGroup): The sequence group to swap in. + + Returns: + List[Tuple[int, int]]: The mapping of swapping block from CPU + to GPU. + """ + physical_block_id_mapping = [] + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + blocks = self.block_tables[seq.seq_id].blocks + if len(blocks) == 0: + continue + + seq_swap_mapping = self.block_allocator.swap(blocks=blocks, + src_device=Device.CPU, + dst_device=Device.GPU) + + # Refresh the block ids of the table (post-swap) + self.block_tables[seq.seq_id].update(blocks) + + seq_physical_block_id_mapping = { + self.block_allocator.get_physical_block_id( + Device.CPU, cpu_block_id): + self.block_allocator.get_physical_block_id( + Device.GPU, gpu_block_id) + for cpu_block_id, gpu_block_id in seq_swap_mapping.items() + } + + physical_block_id_mapping.extend( + list(seq_physical_block_id_mapping.items())) + + return physical_block_id_mapping + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + """Returns whether we can swap out the given sequence_group + with num_lookahead_slots. + + Args: + seq_group (SequenceGroup): The sequence group to swap out. + num_lookahead_slots (int): Number of lookahead slots used in + speculative decoding, default to 0. + + Returns: + bool: Whether it's possible to swap out current sequence group. + """ + alloc_status = self._can_swap(seq_group, Device.CPU, + SequenceStatus.RUNNING) + return alloc_status == AllocStatus.OK + + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + """Returns the block id mapping (from GPU to CPU) generated by + swapping out the given sequence_group with num_lookahead_slots. + + Args: + sequence_group (SequenceGroup): The sequence group to swap out. + + Returns: + List[Tuple[int, int]]: The mapping of swapping block from + GPU to CPU. + """ + physical_block_id_mapping = [] + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + blocks = self.block_tables[seq.seq_id].blocks + if len(blocks) == 0: + continue + + seq_swap_mapping = self.block_allocator.swap(blocks=blocks, + src_device=Device.GPU, + dst_device=Device.CPU) + + # Refresh the block ids of the table (post-swap) + self.block_tables[seq.seq_id].update(blocks) + + seq_physical_block_id_mapping = { + self.block_allocator.get_physical_block_id( + Device.GPU, gpu_block_id): + self.block_allocator.get_physical_block_id( + Device.CPU, cpu_block_id) + for gpu_block_id, cpu_block_id in seq_swap_mapping.items() + } + + physical_block_id_mapping.extend( + list(seq_physical_block_id_mapping.items())) + + return physical_block_id_mapping + + def get_num_free_gpu_blocks(self) -> int: + return self.block_allocator.get_num_free_blocks(Device.GPU) + + def get_num_free_cpu_blocks(self) -> int: + return self.block_allocator.get_num_free_blocks(Device.CPU) + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return self.block_allocator.get_prefix_cache_hit_rate(device) + + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + return self.block_allocator.reset_prefix_cache(device) + + def _can_swap(self, + seq_group: SequenceGroup, + device: Device, + status: SequenceStatus, + num_lookahead_slots: int = 0) -> AllocStatus: + """Returns the AllocStatus for swapping in/out the given sequence_group + on to the 'device'. + + Args: + sequence_group (SequenceGroup): The sequence group to swap in/out. + device (Device): device to swap the 'seq_group' on. + status (SequenceStatus): The status of sequence which is needed + for action. RUNNING for swap out and SWAPPED for swap in + num_lookahead_slots (int): Number of lookahead slots used in + speculative decoding, default to 0. + + Returns: + AllocStatus: The AllocStatus for swapping in/out the given + sequence_group on to the 'device'. + """ + # First determine the number of blocks that will be touched by this + # swap. Then verify if there are available blocks in the device + # to perform the swap. + num_blocks_touched = 0 + blocks: List[Block] = [] + for seq in seq_group.get_seqs(status=status): + block_table = self.block_tables[seq.seq_id] + if block_table.blocks is not None: + # Compute the number blocks to touch for the tokens to be + # appended. This does NOT include the full blocks that need + # to be touched for the swap. + num_blocks_touched += \ + block_table.get_num_blocks_touched_by_append_slots( + block_table.get_unseen_token_ids(seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots) + blocks.extend(block_table.blocks) + # Compute the number of full blocks to touch and add it to the + # existing count of blocks to touch. + num_blocks_touched += self.block_allocator.get_num_full_blocks_touched( + blocks, device=device) + + watermark_blocks = 0 + if device == Device.GPU: + watermark_blocks = self.watermark_blocks + + if self.block_allocator.get_num_total_blocks( + device) < num_blocks_touched: + return AllocStatus.NEVER + elif self.block_allocator.get_num_free_blocks( + device) - num_blocks_touched >= watermark_blocks: + return AllocStatus.OK + else: + return AllocStatus.LATER + + def get_num_cached_tokens(self, seq: Sequence) -> int: + """Get the number of tokens in blocks that are already computed and + cached in the block manager for the sequence. + """ + return self._computed_blocks_tracker.get_num_cached_tokens(seq) diff --git a/vllm/core/evictor.py b/vllm/core/evictor.py new file mode 100644 index 0000000..7ec4768 --- /dev/null +++ b/vllm/core/evictor.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import enum +import heapq +from abc import ABC, abstractmethod +from typing import Dict, List, Tuple + + +class EvictionPolicy(enum.Enum): + """Enum for eviction policy used by make_evictor to instantiate the correct + Evictor subclass. + """ + LRU = enum.auto() + + +class Evictor(ABC): + """The Evictor subclasses should be used by the BlockAllocator class to + handle eviction of freed Blocks. + """ + + @abstractmethod + def __init__(self): + pass + + @abstractmethod + def __contains__(self, block_id: int) -> bool: + pass + + @abstractmethod + def evict(self) -> Tuple[int, int]: + """Runs the eviction algorithm and returns the evicted block's + content hash along with physical block id along with physical block id + """ + pass + + @abstractmethod + def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, + last_accessed: float): + """Adds block to the evictor, making it a candidate for eviction""" + pass + + @abstractmethod + def update(self, block_id: int, last_accessed: float): + """Update corresponding block's access time in metadata""" + pass + + @abstractmethod + def remove(self, block_id: int): + """Remove a given block id from the cache.""" + pass + + @property + @abstractmethod + def num_blocks(self) -> int: + pass + + +class BlockMetaData: + """Data structure for storing key data describe cached block, so that + evitor could use to make its decision which one to choose for eviction + + Here we use physical block id as the dict key, as there maybe several + blocks with the same content hash, but their physical id is unique. + """ + + def __init__(self, content_hash: int, num_hashed_tokens: int, + last_accessed: float): + self.content_hash = content_hash + self.num_hashed_tokens = num_hashed_tokens + self.last_accessed = last_accessed + + +class LRUEvictor(Evictor): + """Evicts in a least-recently-used order using the last_accessed timestamp + that's recorded in the Block. If there are multiple blocks with + the same last_accessed time, then the one with the largest num_hashed_tokens + will be evicted. If two blocks each have the lowest last_accessed time and + highest num_hashed_tokens value, then one will be chose arbitrarily + """ + + # CLEANUP_THRESHOLD determines the maximum allowable size of the priority + # queue relative to the free table size. When this threshold is exceeded, + # a cleanup operation is triggered to reduce memory usage. + CLEANUP_THRESHOLD = 50 + + def __init__(self): + self.free_table: Dict[int, BlockMetaData] = {} + self.priority_queue = [] + + def __contains__(self, block_id: int) -> bool: + return block_id in self.free_table + + def evict(self) -> Tuple[int, int]: + if len(self.free_table) == 0: + raise ValueError("No usable cache memory left") + + while self.priority_queue: + # We do not remove outdated entries from the priority queue at the + # time of updating the last_accessed timestamp. Instead, outdated + # entries are filtered out here during eviction. Outdated entries + # would either not in the free table, or have older last accessed + # time. + last_accessed, _, block_id, content_hash = heapq.heappop( + self.priority_queue) + if (block_id in self.free_table and + self.free_table[block_id].last_accessed == last_accessed): + self.free_table.pop(block_id) + return block_id, content_hash + + raise ValueError("No usable cache memory left") + + def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, + last_accessed: float): + self.free_table[block_id] = BlockMetaData(content_hash, + num_hashed_tokens, + last_accessed) + heapq.heappush( + self.priority_queue, + (last_accessed, -num_hashed_tokens, block_id, content_hash)) + self._cleanup_if_necessary() + + def update(self, block_id: int, last_accessed: float): + self.free_table[block_id].last_accessed = last_accessed + + def _cleanup_if_necessary(self): + if len(self.priority_queue) > LRUEvictor.CLEANUP_THRESHOLD * len( + self.free_table): + self._cleanup() + + def _cleanup(self): + new_priority_queue: List[Tuple[float, int, int, int]] = [] + + for block_id, block in self.free_table.items(): + new_priority_queue.append( + (block.last_accessed, -block.num_hashed_tokens, block_id, + block.content_hash)) + heapq.heapify(new_priority_queue) + + self.priority_queue = new_priority_queue + + def remove(self, block_id: int): + if block_id not in self.free_table: + raise ValueError( + "Attempting to remove block that's not in the evictor") + self.free_table.pop(block_id) + + @property + def num_blocks(self) -> int: + return len(self.free_table) + + +def make_evictor(eviction_policy: EvictionPolicy) -> Evictor: + if eviction_policy == EvictionPolicy.LRU: + return LRUEvictor() + else: + raise ValueError(f"Unknown cache eviction policy: {eviction_policy}") diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py new file mode 100644 index 0000000..69b9169 --- /dev/null +++ b/vllm/core/interfaces.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import enum +from abc import ABC, abstractmethod +from typing import List, Optional +from typing import Sequence as GenericSequence +from typing import Tuple + +from vllm.sequence import Sequence, SequenceGroup +from vllm.utils import Device + + +class AllocStatus(enum.Enum): + """Result for BlockSpaceManager.can_allocate + + 1. Ok: seq_group can be allocated now. + 2. Later: seq_group cannot be allocated. + The capacity of allocator is larger than seq_group required. + 3. Never: seq_group can never be allocated. + The seq_group is too large to allocated in GPU. + """ + OK = enum.auto() + LATER = enum.auto() + NEVER = enum.auto() + + +class BlockSpaceManager(ABC): + + @staticmethod + def get_block_space_manager_class(version: str): + version = version.lower() + + if version == "selfattn": + from vllm.core.block_manager import SelfAttnBlockSpaceManager + return SelfAttnBlockSpaceManager + + if version == "placeholder": + from vllm.core.placeholder_block_space_manager import ( + PlaceholderBlockSpaceManager) + return PlaceholderBlockSpaceManager + + raise ValueError(f"Unknown version {version=}") + + @abstractmethod + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: + pass + + @abstractmethod + def allocate(self, seq_group: SequenceGroup) -> None: + pass + + @abstractmethod + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: + pass + + @abstractmethod + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int, + ) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + pass + + @abstractmethod + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> AllocStatus: + pass + + @abstractmethod + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + pass + + @abstractmethod + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def free(self, seq: Sequence) -> None: + pass + + @abstractmethod + def get_block_table(self, seq: Sequence) -> List[int]: + pass + + @abstractmethod + def get_num_free_gpu_blocks(self) -> int: + pass + + @abstractmethod + def get_num_free_cpu_blocks(self) -> int: + pass + + @abstractmethod + def access_all_blocks_in_seq( + self, + seq: Sequence, + access_time: float, + ) -> None: + pass + + @abstractmethod + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: + pass + + @abstractmethod + def mark_blocks_as_computed(self, seq_group: SequenceGroup, + token_chunk_size: int): + pass + + @abstractmethod + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + pass + + @abstractmethod + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + """Reset prefix cache for specified or all devices.""" + pass + + @abstractmethod + def get_num_cached_tokens(self, seq: Sequence) -> int: + pass + + @abstractmethod + def remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None: + pass \ No newline at end of file diff --git a/vllm/core/placeholder_block_space_manager.py b/vllm/core/placeholder_block_space_manager.py new file mode 100644 index 0000000..6795159 --- /dev/null +++ b/vllm/core/placeholder_block_space_manager.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import List, Optional, Tuple + +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.sequence import Sequence, SequenceGroup +from vllm.utils import Device + + +class PlaceholderBlockSpaceManager(BlockSpaceManager): + """A version of BlockSpaceManager for use in environments + where block management is not required. + For example: pooling models or attention-free models like Mamba. + + This class provides the same interface as BlockSpaceManager, but its + methods perform no actions or return simple values like True in specific + actions. It's designed to be used in scenarios where the overhead of + block management is unnecessary, such as in an embedding environment. + """ + + def __init__( + self, + **kwargs, + ) -> None: + pass + + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: + # Always return OK for dummy purposes + return AllocStatus.OK + + def allocate(self, seq_group: SequenceGroup) -> None: + # No actual allocation logic needed + pass + + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: + return True + + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int, + ) -> List[Tuple[int, int]]: + return [] + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + pass + + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> AllocStatus: + return AllocStatus.OK + + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + return None # type: ignore + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + return True + + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + return None # type: ignore + + def free(self, seq: Sequence) -> None: + # No operation on free + return + + def get_block_table(self, seq: Sequence) -> List[int]: + return None # type: ignore + + def get_num_free_gpu_blocks(self) -> int: + return 1 + + def get_num_free_cpu_blocks(self) -> int: + return 1 + + def access_all_blocks_in_seq( + self, + seq: Sequence, + access_time: float, + ) -> None: + pass + + def get_common_computed_block_ids(self, + seq_group: List[Sequence]) -> List[int]: + return [] + + def mark_blocks_as_computed(self, seq_group: SequenceGroup, + token_chunk_size: int): + pass + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return -1 + + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + return True + + def get_num_cached_tokens(self, seq: Sequence) -> int: + return 0 + + def remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None: + return diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py new file mode 100644 index 0000000..0ef0396 --- /dev/null +++ b/vllm/core/scheduler.py @@ -0,0 +1,2126 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import enum +import os +import random +import time +from collections import deque +from dataclasses import dataclass, field +from typing import Callable, Deque, Dict, Iterable, List, Optional +from typing import Sequence as GenericSequence +from typing import Set, Tuple, Union + +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sequence import (Sequence, SequenceData, SequenceGroup, + SequenceGroupBase, SequenceGroupMetadata, + SequenceGroupMetadataDelta, SequenceStage, + SequenceStatus) +from vllm.utils import Device, PyObjectCache + +logger = init_logger(__name__) + +# Test-only. If configured, decode is preempted with +# ARTIFICIAL_PREEMPTION_PROB% probability. +ENABLE_ARTIFICIAL_PREEMPT = bool( + os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa +ARTIFICIAL_PREEMPTION_PROB = 0.5 +ARTIFICIAL_PREEMPTION_MAX_CNT = 500 + + +class PreemptionMode(enum.Enum): + """Preemption modes. + + 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory + and swap them back in when the sequences are resumed. + 2. Recomputation: Discard the blocks of the preempted sequences and + recompute them when the sequences are resumed, treating the sequences as + new prompts. + """ + + SWAP = enum.auto() + RECOMPUTE = enum.auto() + + +@dataclass +class SchedulingBudget: + """The available slots for scheduling. + + TODO(sang): Right now, the budget is request_id-aware meaning it can ignore + budget update from the same request_id. It is because in normal scheduling + path, we update RUNNING num_seqs ahead of time, meaning it could be + updated more than once when scheduling RUNNING requests. Since this won't + happen if we only have chunked prefill scheduling, we can remove this + feature from the API when chunked prefill is enabled by default. + """ + + token_budget: int + max_num_seqs: int + _request_ids_num_batched_tokens: Set[str] = field(default_factory=set) + _request_ids_num_curr_seqs: Set[str] = field(default_factory=set) + # Number of cached tokens in the batch. + _num_cached_tokens: int = 0 + # Number of actual non-cached tokens in the batch. + _num_batched_tokens: int = 0 + _num_curr_seqs: int = 0 + + def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): + # We allow num_new_tokens to be 0 when the entire sequence has + # been cached. + assert num_new_tokens >= 0 + assert num_new_seqs != 0 + return (self.num_batched_tokens + num_new_tokens <= self.token_budget + and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs) + + def remaining_token_budget(self): + return self.token_budget - self.num_batched_tokens + + def add_num_batched_tokens(self, + req_id: str, + num_batched_tokens: int, + num_cached_tokens: int = 0): + if req_id in self._request_ids_num_batched_tokens: + return + assert num_cached_tokens >= 0 + assert num_batched_tokens >= 0 + + self._request_ids_num_batched_tokens.add(req_id) + self._num_batched_tokens += num_batched_tokens + self._num_cached_tokens += num_cached_tokens + + def subtract_num_batched_tokens(self, req_id: str, + num_batched_tokens: int): + if req_id in self._request_ids_num_batched_tokens: + self._request_ids_num_batched_tokens.remove(req_id) + self._num_batched_tokens -= num_batched_tokens + + def add_num_seqs(self, req_id: str, num_curr_seqs: int): + if req_id in self._request_ids_num_curr_seqs: + return + + self._request_ids_num_curr_seqs.add(req_id) + self._num_curr_seqs += num_curr_seqs + + def subtract_num_seqs(self, req_id: str, num_curr_seqs: int): + if req_id in self._request_ids_num_curr_seqs: + self._request_ids_num_curr_seqs.remove(req_id) + self._num_curr_seqs -= num_curr_seqs + + @property + def num_batched_tokens(self): + return self._num_batched_tokens + + @property + def num_curr_seqs(self): + return self._num_curr_seqs + + @property + def num_cached_tokens(self): + return self._num_cached_tokens + + +@dataclass +class ScheduledSequenceGroup: + # A sequence group that's scheduled. + seq_group: SequenceGroup + # The total chunk size (number of tokens) to process for next iteration. + # 1 for decoding. Same as prompt tokens for prefill, but if prefill is + # chunked, it can be smaller than that. + token_chunk_size: int + + +@dataclass +class SchedulerOutputs: + """The scheduling decision made from a scheduler.""" + + # Scheduled sequence groups. + scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup] + # Number of prefill groups scheduled. + num_prefill_groups: int + # Total number of batched tokens. + num_batched_tokens: int + # Blocks to swap in. List of CPU -> GPU block number. + blocks_to_swap_in: List[Tuple[int, int]] + # Blocks to swap out. List of GPU -> CPU block number. + blocks_to_swap_out: List[Tuple[int, int]] + # Blocks to copy. Source to dest block. + blocks_to_copy: List[Tuple[int, int]] + # Sequence groups that are going to be ignored. + ignored_seq_groups: List[SequenceGroup] + # The number of slots for lookahead decoding. + num_lookahead_slots: int + # The number of requests in the running queue + running_queue_size: int + preempted: int + + def __post_init__(self): + # Swap in and swap out should never happen at the same time. + assert not (self.blocks_to_swap_in and self.blocks_to_swap_out) + + self.num_loras: int = len(self.lora_requests) + if self.num_loras > 0: + self._sort_by_lora_ids() + + self.num_prompt_adapters: int = len(self.prompt_adapter_requests) + + def is_empty(self) -> bool: + # NOTE: We do not consider the ignored sequence groups. + return (not self.scheduled_seq_groups and not self.blocks_to_swap_in + and not self.blocks_to_swap_out and not self.blocks_to_copy) + + def _sort_by_lora_ids(self): + assert 0 <= self.num_prefill_groups <= len(self.scheduled_seq_groups) + + def key_fn(group: ScheduledSequenceGroup): + key = (group.seq_group.lora_int_id, group.seq_group.request_id) + if 0 < self.num_prefill_groups < len(self.scheduled_seq_groups): + # Sort sequence groups so that all prefills come before all + # decodes as required by chunked prefill. + return (not group.seq_group.is_prefill(), *key) + return key + + self.scheduled_seq_groups = sorted(self.scheduled_seq_groups, + key=key_fn) + + @property + def lora_requests(self) -> Set[LoRARequest]: + return { + g.seq_group.lora_request + for g in self.scheduled_seq_groups + if g.seq_group.lora_request is not None + } + + @property + def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]: + return { + g.seq_group.prompt_adapter_request + for g in self.scheduled_seq_groups + if g.seq_group.prompt_adapter_request is not None + } + + +@dataclass +class SchedulerRunningOutputs: + """The requests that are scheduled from a running queue. + + Could contain prefill (prefill that's chunked) or decodes. If there's not + enough memory, it can be preempted (for recompute) or swapped out. + """ + + # Selected sequences that are running and in a decoding phase. + decode_seq_groups: List[ScheduledSequenceGroup] + # Selected sequences that are running and in a prefill phase. + # I.e., it means the prefill has been chunked. + prefill_seq_groups: List[ScheduledSequenceGroup] + # The preempted sequences. + preempted: List[SequenceGroup] + # Sequences that are swapped out. + swapped_out: List[SequenceGroup] + # The blocks to swap out. + blocks_to_swap_out: List[Tuple[int, int]] + # The blocks to copy. + blocks_to_copy: List[Tuple[int, int]] + # The number of slots for lookahead decoding. + num_lookahead_slots: int + + # Optimization for fast-access to seq_group lists + decode_seq_groups_list: List[SequenceGroup] + prefill_seq_groups_list: List[SequenceGroup] + + @classmethod + def create_empty(cls) -> "SchedulerRunningOutputs": + return SchedulerRunningOutputs( + decode_seq_groups=[], + prefill_seq_groups=[], + preempted=[], + swapped_out=[], + blocks_to_swap_out=[], + blocks_to_copy=[], + num_lookahead_slots=0, + decode_seq_groups_list=[], + prefill_seq_groups_list=[], + ) + + +@dataclass +class SchedulerSwappedInOutputs: + """The requests that are scheduled from a swap queue. + + Could contain prefill (prefill that's chunked) or decodes. + """ + + # Selected sequences that are going to be swapped in and is in a + # decoding phase. + decode_seq_groups: List[ScheduledSequenceGroup] + # Selected sequences that are going to be swapped in and in a prefill + # phase. I.e., it means the prefill has been chunked. + prefill_seq_groups: List[ScheduledSequenceGroup] + # The blocks to swap in. + blocks_to_swap_in: List[Tuple[int, int]] + # The blocks to copy. + blocks_to_copy: List[Tuple[int, int]] + # The number of slots for lookahead decoding. + num_lookahead_slots: int + # Infeasible sequence groups. + infeasible_seq_groups: List[SequenceGroup] + + @classmethod + def create_empty(cls) -> "SchedulerSwappedInOutputs": + return SchedulerSwappedInOutputs( + decode_seq_groups=[], + prefill_seq_groups=[], + blocks_to_swap_in=[], + blocks_to_copy=[], + num_lookahead_slots=0, + infeasible_seq_groups=[], + ) + + +@dataclass +class SchedulerPrefillOutputs: + """The requests that are scheduled from a waiting queue. + + Could contain a fresh prefill requests or preempted requests that need + to be recomputed from scratch. + """ + + # Selected sequences for prefill. + seq_groups: List[ScheduledSequenceGroup] + # Ignored sequence groups. + ignored_seq_groups: List[SequenceGroup] + num_lookahead_slots: int + + @classmethod + def create_empty(cls) -> "SchedulerPrefillOutputs": + return SchedulerPrefillOutputs( + seq_groups=[], + ignored_seq_groups=[], + num_lookahead_slots=0, + ) + + +def seq_group_metadata_builder(): + return SequenceGroupMetadata(request_id="", + is_prompt=False, + seq_data={}, + sampling_params=None, + block_tables={}) + + +def scheduler_running_outputs_builder(): + return SchedulerRunningOutputs(decode_seq_groups=[], + prefill_seq_groups=[], + preempted=[], + swapped_out=[], + blocks_to_swap_out=[], + blocks_to_copy=[], + num_lookahead_slots=0, + prefill_seq_groups_list=[], + decode_seq_groups_list=[]) + + +def scheduled_seq_group_builder(): + return ScheduledSequenceGroup(SequenceGroup.__new__(SequenceGroup), + token_chunk_size=0) + # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) + + +@dataclass +class PartialPrefillMetadata: + """Holds information about the partial prefills that are currently running + during a single iteration of the Scheduler. + When chunked prefill is enabled, we allow a certain number of seqs to be + partially prefilled during each iteration. Having multiple partial prefills + in flight allows us to minimize TTFT and avoid decode starvation in cases + where a single sequence group with a very large prompt blocks the queue for + too many iterations. + The number of long prefill requests is limited so that smaller + requests may jump the queue in front of them and get to the decode + phase faster. + """ + + # A minimum bound on the total number of prefills to be scheduled during + # this iteration + schedulable_prefills: int + + # The number of long prefill requests currently running + long_prefills: int + + scheduler_config: SchedulerConfig + + def can_schedule(self, seq_group: SequenceGroup) -> bool: + """When concurrent partial prefills are enabled, + we limit the number of long requests and only accept + shorter requests from the queue while running them + concurrently""" + return not (seq_group.first_seq.get_num_new_tokens() + > self.scheduler_config.long_prefill_token_threshold + and self.long_prefills + >= self.scheduler_config.max_long_partial_prefills + and self.scheduler_config.max_num_partial_prefills > 1) + + def maybe_increment_partial_prefills(self, + seq_group: SequenceGroup) -> None: + # When a new prefill is scheduled, we need to know if it is a + # long request + if (seq_group.first_seq.get_num_new_tokens() + > self.scheduler_config.long_prefill_token_threshold): + self.long_prefills += 1 + + @classmethod + def from_queues( + cls, + running: Deque[SequenceGroup], + waiting: Deque[SequenceGroup], + scheduler_config: SchedulerConfig, + ) -> "PartialPrefillMetadata": + """Create a PartialPrefillMetadata object from the current state of + the scheduler's queues. + This accounts for the currently running prefill requests, and peeks into + the waiting queue to see if there are more prefills to potentially be + scheduled during this iteration.""" + prefills = 0 + long_prefills = 0 + + waiting_long_prefills = 0 + + for sg in running: + if sg.first_seq.data.stage == SequenceStage.PREFILL: + prefills += 1 + if (sg.first_seq.get_num_new_tokens() + > scheduler_config.long_prefill_token_threshold): + long_prefills += 1 + + for sg in waiting: + # Don't bother looping through the rest of the queue if we know + # there are already at + # least max_partial_prefills requests to fill + if prefills >= scheduler_config.max_num_partial_prefills: + break + + # Don't count long requests from the waiting queue if we aren't + # going to schedule them anyway + if (sg.first_seq.get_num_new_tokens() + > scheduler_config.long_prefill_token_threshold): + if (long_prefills + waiting_long_prefills + >= scheduler_config.max_long_partial_prefills): + continue + waiting_long_prefills += 1 + prefills += 1 + + # NB: long_prefills and waiting_long_prefills are tracked separately. + # We don't account for the waiting requests here because we need to use + # this metadata to track how many have actually been scheduled. + return PartialPrefillMetadata( + schedulable_prefills=min( + prefills, scheduler_config.max_num_partial_prefills), + long_prefills=long_prefills, + scheduler_config=scheduler_config, + ) + + +class Scheduler: + + def __init__( + self, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + lora_config: Optional[LoRAConfig], + pipeline_parallel_size: int = 1, + output_proc_callback: Optional[Callable] = None, + ) -> None: + self.scheduler_config = scheduler_config + self.cache_config = cache_config + # Note for LoRA scheduling: the current policy is extremely + # simple and NOT fair. It can lead to starvation of some + # LoRAs. This should be improved in the future. + self.lora_config = lora_config + + version = "selfattn" + if (self.scheduler_config.runner_type == "pooling" + or self.cache_config.is_attention_free): + version = "placeholder" + + BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( + version) + + num_gpu_blocks = cache_config.num_gpu_blocks + if num_gpu_blocks: + num_gpu_blocks //= pipeline_parallel_size + + num_cpu_blocks = cache_config.num_cpu_blocks + if num_cpu_blocks: + num_cpu_blocks //= pipeline_parallel_size + + # Create the block space manager. + self.block_manager = BlockSpaceManagerImpl( + block_size=self.cache_config.block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + sliding_window=self.cache_config.sliding_window, + enable_caching=self.cache_config.enable_prefix_caching, + ) + + # Sequence groups in the WAITING state. + # Contain new prefill or preempted requests. + self.waiting: Deque[SequenceGroup] = deque() + # Sequence groups in the RUNNING state. + # Contain decode requests. + self.running: Deque[SequenceGroup] = deque() + # Sequence groups in the SWAPPED state. + # Contain decode requests that are swapped out. + self.swapped: Deque[SequenceGroup] = deque() + # Sequence groups finished requests ids since last step iteration. + # It lets the model know that any state associated with these requests + # can and must be released after the current step. + # This is used to evict the finished requests from the Mamba cache. + self._finished_requests_ids: List[str] = list() + # Time at previous scheduling step + self.prev_time = 0.0 + # Did we schedule a prompt at previous step? + self.prev_prompt = False + # Latency of the last prompt step + self.last_prompt_latency = 0.0 + # preemption mode, RECOMPUTE or SWAP + self.user_specified_preemption_mode = scheduler_config.preemption_mode + + # The following field is test-only. It is used to inject artificial + # preemption. + self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT + self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT + if self.enable_artificial_preemption + else 0) + self.num_cumulative_preemption: int = 0 + + # Used to cache python objects + self._seq_group_metadata_cache: List[PyObjectCache] = [] + self._scheduler_running_outputs_cache: List[PyObjectCache] = [] + self._scheduled_seq_group_cache: List[PyObjectCache] = [] + + # For async output processing, we need to swap cache buffers between + # iterations. I.e. since the output processing is lagged one step, + # we cannot reuse the cached objects immediately when the schedule() + # is called again, but only when schedule() is called the second time. + self.output_proc_callback = output_proc_callback + self.use_async_output_proc = self.output_proc_callback is not None + self.num_cache_iters = 2 if self.use_async_output_proc else 1 + + self.cache_id = 0 + for i in range(self.num_cache_iters): + self._seq_group_metadata_cache.append( + PyObjectCache(seq_group_metadata_builder)) + self._scheduler_running_outputs_cache.append( + PyObjectCache(scheduler_running_outputs_builder)) + self._scheduled_seq_group_cache.append( + PyObjectCache(scheduled_seq_group_builder)) + + # For async postprocessor, the extra decode run cannot be done + # when the request reaches max_model_len. In this case, the request + # will be stopped during schedule() call and added to this stop list + # for processing and deallocation by the free_finished_seq_groups() + self._async_stopped: List[SequenceGroup] = [] + + # List with the chunk sizes to hand out to each sequence depending + # on how many partial prefills are running. This is slightly faster than + # running an integer division every time a prefill is scheduled. + # This splits the budget evenly among all prefills. + self.partial_prefill_budget_lookup_list = [0] * ( + self.scheduler_config.max_num_partial_prefills + 1) + self.partial_prefill_budget_lookup_list[0] = ( + scheduler_config.max_num_batched_tokens) + for i in range(1, self.scheduler_config.max_num_partial_prefills + 1): + self.partial_prefill_budget_lookup_list[i] = ( + scheduler_config.max_num_batched_tokens // i) + + @property + def next_cache_id(self): + return (self.cache_id + 1) % self.num_cache_iters + + @property + def lora_enabled(self) -> bool: + return bool(self.lora_config) + + @property + def num_decoding_tokens_per_seq(self) -> int: + """The number of new tokens.""" + return 1 + + def add_seq_group(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the waiting queue. + self.waiting.append(seq_group) + + def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the running queue. + # Only for testing purposes. + self.running.append(seq_group) + + def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the swapped queue. + # Only for testing purposes. + self.swapped.append(seq_group) + + def abort_seq_group( + self, + request_id: Union[str, Iterable[str]], + seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None, + ) -> None: + """Aborts a sequence group with the given ID. + + Check if the sequence group with the given ID + is present in any of the state queue. + If present, remove the sequence group from the state queue. + Also, if any of the sequences in the sequence group is not finished, + free the sequence with status `FINISHED_ABORTED`. + Otherwise, do nothing. + + Args: + request_id: The ID(s) of the sequence group to abort. + seq_id_to_seq_group: helper for groups with n>1 + """ + if isinstance(request_id, str): + request_id = (request_id, ) + request_ids = set(request_id) + seq_id_to_seq_group = seq_id_to_seq_group or {} + for state_queue in [self.waiting, self.running, self.swapped]: + aborted_groups: List[SequenceGroup] = [] + for seq_group in state_queue: + # When n>1, seq_group.request_id looks like + # foo_parallel_sample_0, while request_ids is just foo, and we + # should resolve it as real_request_id to match. + if seq_group.request_id in seq_id_to_seq_group: + real_request_id = seq_id_to_seq_group[ + seq_group.request_id].group_id + else: + real_request_id = seq_group.request_id + if real_request_id in request_ids: + # Appending aborted group into pending list. + aborted_groups.append(seq_group) + # We can't remove real_request_id in request_ids here, + # because there may be other seq groups sharing the same + # real_request_id + for aborted_group in aborted_groups: + # Remove the sequence group from the state queue. + state_queue.remove(aborted_group) + # Remove the aborted request from the Mamba cache. + self._finished_requests_ids.append(aborted_group.request_id) + for seq in aborted_group.get_seqs(): + if seq.is_finished(): + continue + seq.status = SequenceStatus.FINISHED_ABORTED + self.free_seq(seq) + if aborted_group.request_id in seq_id_to_seq_group: + del seq_id_to_seq_group[aborted_group.request_id] + + self._free_seq_group_cross_attn_blocks(aborted_group) + + def _free_seq_group_cross_attn_blocks( + self, + seq_group: SequenceGroup, + ) -> None: + """ + Free a sequence group from a cross-attention block table. + Has no effect on decoder-only models. + """ + if seq_group.is_encoder_decoder(): + self.block_manager.free_cross(seq_group) + + def has_unfinished_seqs(self) -> bool: + return (len(self.waiting) != 0 or len(self.running) != 0 + or len(self.swapped) != 0) + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return self.block_manager.get_prefix_cache_hit_rate(device) + + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + return self.block_manager.reset_prefix_cache(device) + + def get_num_unfinished_seq_groups(self) -> int: + return len(self.waiting) + len(self.running) + len(self.swapped) + + def get_and_reset_finished_requests_ids(self) -> List[str]: + """Flushes the list of request ids of previously finished seq_groups.""" + finished_requests_ids = self._finished_requests_ids + self._finished_requests_ids = list() + return finished_requests_ids + + def _schedule_running( + self, + budget: SchedulingBudget, + curr_loras: Optional[Set[int]], + enable_chunking: bool = False, + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, + ) -> SchedulerRunningOutputs: + """Schedule sequence groups that are running. + + Running queue should include decode and chunked prefill requests. + + Args: + budget: The scheduling budget. The argument is in-place updated + when any decodes are preempted. + curr_loras: Currently batched lora request ids. The argument is + in-place updated when any decodes are preempted. + enable_chunking: If True, seq group can be chunked and only a + chunked number of tokens are scheduled if + `budget.num_batched_tokens` has not enough capacity to schedule + all tokens. + partial_prefill_metadata: information about the partial prefills + that are currently running + + Returns: + SchedulerRunningOutputs. + """ + ret: SchedulerRunningOutputs = self._scheduler_running_outputs_cache[ + self.cache_id].get_object() + ret.blocks_to_swap_out.clear() + ret.blocks_to_copy.clear() + ret.decode_seq_groups.clear() + ret.prefill_seq_groups.clear() + ret.preempted.clear() + ret.swapped_out.clear() + + ret.num_lookahead_slots = self._get_num_lookahead_slots( + is_prefill=False, enable_chunking=enable_chunking) + + ret.decode_seq_groups_list.clear() + ret.prefill_seq_groups_list.clear() + + # Blocks that need to be swapped or copied before model execution. + blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out + blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy + + decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups + prefill_seq_groups: List[ + ScheduledSequenceGroup] = ret.prefill_seq_groups + preempted: List[SequenceGroup] = ret.preempted + swapped_out: List[SequenceGroup] = ret.swapped_out + + running_queue = self.running + assert len(self._async_stopped) == 0 + while running_queue: + seq_group = running_queue[0] + # We discard the cached tokens info here because we don't need it + # for running sequence: + # 1. If a sequence is running with chunked prefill, the cached + # tokens info was already used for the first prefill. + # 2. If a sequence is running with non-chunked prefill, then + # there it's a decoding sequence, and the cached tokens info is + # irrelevant. + num_uncached_new_tokens, _ = \ + self._get_num_new_uncached_and_cached_tokens( + seq_group, + SequenceStatus.RUNNING, + enable_chunking, + budget, + partial_prefill_metadata, + ) + + num_running_tokens = num_uncached_new_tokens + if num_running_tokens == 0: + # No budget => Stop + break + + running_queue.popleft() + + # With async postprocessor, an extra decode run is done + # to process the final tokens. The check below avoids this extra + # decode run when the model max len is reached, in order to avoid + # a memory overflow. + if (self.use_async_output_proc and seq_group.seqs[0].get_len() + > self.scheduler_config.max_model_len): + self._async_stopped.append(seq_group) + continue + + # NOTE(woosuk): Preemption happens only when there is no available + # slot to keep all the sequence groups in the RUNNING state. + while not self._can_append_slots(seq_group, enable_chunking): + budget.subtract_num_batched_tokens(seq_group.request_id, + num_running_tokens) + num_running_seqs = seq_group.get_max_num_running_seqs() + budget.subtract_num_seqs(seq_group.request_id, + num_running_seqs) + + if (curr_loras is not None and seq_group.lora_int_id > 0 + and seq_group.lora_int_id in curr_loras): + curr_loras.remove(seq_group.lora_int_id) + + # Determine victim sequence + cont_loop = True + if running_queue: + # Preempt the lowest-priority sequence group. + victim_seq_group = running_queue.pop() + else: + # No other sequence group can be preempted. + # Preempt the current sequence group. + # Note: This is also where we stop this loop + # (since there is nothing else to preempt) + victim_seq_group = seq_group + cont_loop = False + + # With async postprocessor, before preempting a sequence + # we need to ensure it has no pending async postprocessor + do_preempt = True + if self.use_async_output_proc: + assert self.output_proc_callback is not None + self.output_proc_callback( + request_id=victim_seq_group.request_id) + + # It may be that the async pending "victim_seq_group" + # becomes finished, in which case we simply free it. + if victim_seq_group.is_finished(): + self._free_finished_seq_group(victim_seq_group) + do_preempt = False + + # Do preemption + if do_preempt: + preempted_mode = self._preempt(victim_seq_group, + blocks_to_swap_out) + if preempted_mode == PreemptionMode.RECOMPUTE: + preempted.append(victim_seq_group) + else: + swapped_out.append(victim_seq_group) + + if not cont_loop: + break + else: + self._append_slots(seq_group, blocks_to_copy, enable_chunking) + is_prefill = seq_group.is_prefill() + + scheduled_seq_group: ScheduledSequenceGroup = ( + self._scheduled_seq_group_cache[ + self.cache_id].get_object()) + scheduled_seq_group.seq_group = seq_group + if is_prefill: + scheduled_seq_group.token_chunk_size = num_running_tokens + prefill_seq_groups.append(scheduled_seq_group) + ret.prefill_seq_groups_list.append(seq_group) + else: + scheduled_seq_group.token_chunk_size = 1 + decode_seq_groups.append(scheduled_seq_group) + ret.decode_seq_groups_list.append(seq_group) + + budget.add_num_batched_tokens(seq_group.request_id, + num_running_tokens) + # OPTIMIZATION: Note that get_max_num_running_seqs is + # expensive. For the default scheduling chase where + # enable_chunking is False, num_seqs are updated before running + # this method, so we don't have to update it again here. + if enable_chunking: + num_running_seqs = seq_group.get_max_num_running_seqs() + budget.add_num_seqs(seq_group.request_id, num_running_seqs) + if curr_loras is not None and seq_group.lora_int_id > 0: + curr_loras.add(seq_group.lora_int_id) + + self._scheduler_running_outputs_cache[self.next_cache_id].reset() + self._scheduled_seq_group_cache[self.next_cache_id].reset() + + return ret + + def _schedule_swapped( + self, + budget: SchedulingBudget, + curr_loras: Optional[Set[int]], + enable_chunking: bool = False, + ) -> SchedulerSwappedInOutputs: + """Schedule sequence groups that are swapped out. + + It schedules swapped requests as long as it fits `budget` and + curr_loras <= max_lora from the scheduling config. The input arguments + `budget` and `curr_loras` are updated based on scheduled seq_groups. + + Args: + budget: The scheduling budget. The argument is in-place updated + when any requests are swapped in. + curr_loras: Currently batched lora request ids. The argument is + in-place updated when any requests are swapped in. + enable_chunking: If True, seq group can be chunked and only a + chunked number of tokens are scheduled if + `budget.num_batched_tokens` has not enough capacity to schedule + all tokens. + + Returns: + SchedulerSwappedInOutputs. + """ + # Blocks that need to be swapped or copied before model execution. + blocks_to_swap_in: List[Tuple[int, int]] = [] + blocks_to_copy: List[Tuple[int, int]] = [] + decode_seq_groups: List[ScheduledSequenceGroup] = [] + prefill_seq_groups: List[ScheduledSequenceGroup] = [] + infeasible_seq_groups: List[SequenceGroup] = [] + + swapped_queue = self.swapped + + leftover_swapped: Deque[SequenceGroup] = deque() + while swapped_queue: + seq_group = swapped_queue[0] + + # If the sequence group cannot be swapped in, stop. + is_prefill = seq_group.is_prefill() + alloc_status = self.block_manager.can_swap_in( + seq_group, + self._get_num_lookahead_slots(is_prefill, enable_chunking)) + if alloc_status == AllocStatus.LATER: + break + elif alloc_status == AllocStatus.NEVER: + logger.warning( + "Failing the request %s because there's not enough kv " + "cache blocks to run the entire sequence.", + seq_group.request_id, + ) + for seq in seq_group.get_seqs(): + seq.status = SequenceStatus.FINISHED_IGNORED + infeasible_seq_groups.append(seq_group) + swapped_queue.popleft() + continue + + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + assert curr_loras is not None + assert self.lora_config is not None + if (lora_int_id > 0 and (lora_int_id not in curr_loras) + and len(curr_loras) >= self.lora_config.max_loras): + # We don't have a space for another LoRA, so + # we ignore this request for now. + leftover_swapped.appendleft(seq_group) + swapped_queue.popleft() + continue + + # The total number of sequences in the RUNNING state should not + # exceed the maximum number of sequences. + num_new_seqs = seq_group.get_max_num_running_seqs() + num_new_tokens_uncached, num_new_tokens_cached = ( + self._get_num_new_uncached_and_cached_tokens( + seq_group, SequenceStatus.SWAPPED, enable_chunking, + budget)) + + if num_new_tokens_uncached == 0 or not budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + ): + self.remove_seq_from_computed_blocks_tracker( + seq_group, SequenceStatus.SWAPPED) + break + + if lora_int_id > 0 and curr_loras is not None: + curr_loras.add(lora_int_id) + swapped_queue.popleft() + self._swap_in(seq_group, blocks_to_swap_in) + self._append_slots(seq_group, blocks_to_copy, enable_chunking) + if is_prefill: + prefill_seq_groups.append( + ScheduledSequenceGroup( + seq_group, + token_chunk_size=num_new_tokens_uncached + + num_new_tokens_cached, + )) + else: + decode_seq_groups.append( + ScheduledSequenceGroup(seq_group, token_chunk_size=1)) + budget.add_num_batched_tokens( + seq_group.request_id, + num_batched_tokens=num_new_tokens_uncached, + num_cached_tokens=num_new_tokens_cached, + ) + budget.add_num_seqs(seq_group.request_id, num_new_seqs) + + swapped_queue.extendleft(leftover_swapped) + + return SchedulerSwappedInOutputs( + decode_seq_groups=decode_seq_groups, + prefill_seq_groups=prefill_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_copy=blocks_to_copy, + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=False, enable_chunking=enable_chunking), + infeasible_seq_groups=infeasible_seq_groups, + ) + + def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: + if (self.scheduler_config.chunked_prefill_enabled + and not self.scheduler_config.is_multi_step): + prompt_limit = self.scheduler_config.max_model_len + else: + prompt_limit = min( + self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens, + ) + + # Model is fine tuned with long context. Return the fine tuned max_len. + if seq_group.lora_request and seq_group.lora_request.long_lora_max_len: + assert prompt_limit <= seq_group.lora_request.long_lora_max_len + return seq_group.lora_request.long_lora_max_len + else: + return prompt_limit + + def _get_priority(self, + seq_group: SequenceGroup) -> Tuple[Optional[int], float]: + """Get the priority of the sequence group. + Highest preference to user-defined priority, followed by arrival time. + Args: + seq_group: The sequence group input. + Returns: + The priority of the sequence group. + """ + return seq_group.priority, seq_group.arrival_time + + def _schedule_priority_preemption( + self, + budget: SchedulingBudget, + ) -> int: + """Sorts waiting and running queue. Also, force preempt requests + from the running queue if their priority is lower. + Priority-based preemption is used with the priority policy. + Args: + budget: The scheduling budget. The argument is in-place updated + when any requests are scheduled. + Returns: + A count of priority-based preemptions. + """ + + waiting_queue = self.waiting + + running_queue = deque(sorted(self.running, key=self._get_priority)) + + blocks_to_swap_out: List[Tuple[int, int]] = [] + force_preemption_count = 0 + + if waiting_queue: + seq_group = waiting_queue.popleft() + num_new_seqs = seq_group.get_max_num_running_seqs() + num_new_tokens_uncached, _ = \ + self._get_num_new_uncached_and_cached_tokens( + seq_group, SequenceStatus.WAITING, False, budget) + + # Only preempt if priority inversion exists + while running_queue and self._get_priority( + running_queue[-1]) > self._get_priority(seq_group): + # Only preempt if waiting sequence cannot be allocated + can_allocate = self.block_manager.can_allocate(seq_group) + if (num_new_tokens_uncached > 0 + and can_allocate == AllocStatus.OK + and budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + )): + break + + # Adjust budget to remove the victim sequence group + vseq_group = running_queue.pop() + num_running_tokens_uncached, _ = ( + self._get_num_new_uncached_and_cached_tokens( + vseq_group, SequenceStatus.RUNNING, False, budget)) + budget.subtract_num_batched_tokens( + vseq_group.request_id, num_running_tokens_uncached) + num_running_seqs = vseq_group.get_max_num_running_seqs() + budget.subtract_num_seqs(vseq_group.request_id, + num_running_seqs) + + # Preempt out the victim sequence group + self._preempt(vseq_group, blocks_to_swap_out) + waiting_queue.appendleft(vseq_group) + force_preemption_count += 1 + # Put the sequence back into the waiting queue + waiting_queue.appendleft(seq_group) + + self.remove_seq_from_computed_blocks_tracker( + seq_group, SequenceStatus.WAITING) + + waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) + + self.waiting = waiting_queue + self.running = running_queue + return force_preemption_count + + def _schedule_prefills( + self, + budget: SchedulingBudget, + curr_loras: Optional[Set[int]], + enable_chunking: bool = False, + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, + ) -> SchedulerPrefillOutputs: + """Schedule sequence groups that are in prefill stage. + + Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE + as a new prefill (that starts from beginning -> most recently generated + tokens). + + It schedules waiting requests as long as it fits `budget` and + curr_loras <= max_lora from the scheduling config. The input arguments + `budget` and `curr_loras` are updated based on scheduled seq_groups. + + Args: + budget: The scheduling budget. The argument is in-place updated + when any requests are scheduled. + curr_loras: Currently batched lora request ids. The argument is + in-place updated when any requests are scheduled. + enable_chunking: If True, seq group can be chunked and only a + chunked number of tokens are scheduled if + `budget.num_batched_tokens` has not enough capacity to schedule + all tokens. + partial_prefill_metadata: information about the partial prefills + that are currently running + + Returns: + SchedulerPrefillOutputs. + """ + if budget.remaining_token_budget() == 0: + # Do nothing: Can't add any more prefill anyway + return SchedulerPrefillOutputs( + seq_groups=[], + ignored_seq_groups=[], + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=True, enable_chunking=enable_chunking), + ) + ignored_seq_groups: List[SequenceGroup] = [] + seq_groups: List[ScheduledSequenceGroup] = [] + using_prompt_embeds: bool = False + + waiting_queue = self.waiting + + leftover_waiting_sequences: Deque[SequenceGroup] = deque() + while self._passed_delay(time.time()) and waiting_queue: + seq_group = waiting_queue[0] + + waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) + assert len(waiting_seqs) == 1, ( + "Waiting sequence group should have only one prompt " + "sequence.") + if (partial_prefill_metadata is not None + and not partial_prefill_metadata.can_schedule(seq_group)): + leftover_waiting_sequences.appendleft(seq_group) + waiting_queue.popleft() + continue + num_new_tokens_uncached, num_new_tokens_cached = ( + self._get_num_new_uncached_and_cached_tokens( + seq_group, + SequenceStatus.WAITING, + enable_chunking, + budget, + partial_prefill_metadata=partial_prefill_metadata, + )) + num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached + + if not enable_chunking: + num_prompt_tokens = waiting_seqs[0].get_len() + assert num_new_tokens == num_prompt_tokens + + prompt_limit = self._get_prompt_limit(seq_group) + if num_new_tokens > prompt_limit: + logger.warning( + "Input prompt (%d tokens) is too long" + " and exceeds limit of %d", + num_new_tokens, + prompt_limit, + ) + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + self.remove_seq_from_computed_blocks_tracker( + seq_group, SequenceStatus.FINISHED_IGNORED) + ignored_seq_groups.append(seq_group) + waiting_queue.popleft() + continue + + num_lookahead_slots: int = 0 + if self.scheduler_config.is_multi_step and enable_chunking: + num_lookahead_slots = self._get_num_lookahead_slots( + True, enable_chunking) + + # If the sequence group cannot be allocated, stop. + can_allocate = self.block_manager.can_allocate( + seq_group, num_lookahead_slots=num_lookahead_slots) + if can_allocate == AllocStatus.LATER: + self.remove_seq_from_computed_blocks_tracker( + seq_group, SequenceStatus.WAITING) + break + elif can_allocate == AllocStatus.NEVER: + logger.warning( + "Input prompt (%d tokens) + lookahead slots (%d) is " + "too long and exceeds the capacity of block_manager", + num_new_tokens, + num_lookahead_slots, + ) + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + self.remove_seq_from_computed_blocks_tracker( + seq_group, SequenceStatus.FINISHED_IGNORED) + ignored_seq_groups.append(seq_group) + waiting_queue.popleft() + continue + + # We cannot mix sequence groups that use prompt embeds and + # those that do not. + if len(seq_groups) == 0: + using_prompt_embeds = seq_group.uses_prompt_embeds() + if using_prompt_embeds != seq_group.uses_prompt_embeds(): + self.remove_seq_from_computed_blocks_tracker( + seq_group, SequenceStatus.WAITING) + leftover_waiting_sequences.appendleft(seq_group) + waiting_queue.popleft() + continue + + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + assert curr_loras is not None + assert self.lora_config is not None + if (self.lora_enabled and lora_int_id > 0 + and lora_int_id not in curr_loras + and len(curr_loras) >= self.lora_config.max_loras): + # We don't have a space for another LoRA, so + # we ignore this request for now. + self.remove_seq_from_computed_blocks_tracker( + seq_group, SequenceStatus.WAITING) + leftover_waiting_sequences.appendleft(seq_group) + waiting_queue.popleft() + continue + + if (budget.num_batched_tokens + >= self.scheduler_config.max_num_batched_tokens): + # We've reached the budget limit - since there might be + # continuous prefills in the running queue, we should break + # to avoid scheduling any new prefills. + self.remove_seq_from_computed_blocks_tracker( + seq_group, SequenceStatus.WAITING) + break + + num_new_seqs = seq_group.get_max_num_running_seqs() + if num_new_tokens_uncached == 0 or not budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + ): + self.remove_seq_from_computed_blocks_tracker( + seq_group, SequenceStatus.WAITING) + break + + # Can schedule this request. + if curr_loras is not None and lora_int_id > 0: + curr_loras.add(lora_int_id) + waiting_queue.popleft() + self._allocate_and_set_running(seq_group) + + if partial_prefill_metadata is not None: + partial_prefill_metadata.maybe_increment_partial_prefills( + seq_group) + + if enable_chunking and self.scheduler_config.is_multi_step: + blocks_to_copy: List[Tuple[int, int]] = [] + # init_multi_step_from_lookahead_slots happens in append_slots + self._append_slots(seq_group, blocks_to_copy, enable_chunking) + # This assert will trip when a copy-on-write happens. This is + # not a concern as the very first sequence-group block + # allocation happens above. Still, we have the assert to + # catch any edge-cases. + assert not blocks_to_copy + else: + seq_group.init_multi_step_from_lookahead_slots( + num_lookahead_slots, + num_scheduler_steps=self.scheduler_config. + num_scheduler_steps, + is_multi_step=self.scheduler_config.is_multi_step, + enable_chunking=enable_chunking, + ) + + seq_groups.append( + ScheduledSequenceGroup(seq_group=seq_group, + token_chunk_size=num_new_tokens)) + budget.add_num_batched_tokens( + seq_group.request_id, + num_batched_tokens=num_new_tokens_uncached, + num_cached_tokens=num_new_tokens_cached, + ) + budget.add_num_seqs(seq_group.request_id, num_new_seqs) + + # Queue requests that couldn't be scheduled. + waiting_queue.extendleft(leftover_waiting_sequences) + if len(seq_groups) > 0: + self.prev_prompt = True + + return SchedulerPrefillOutputs( + seq_groups=seq_groups, + ignored_seq_groups=ignored_seq_groups, + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=True, enable_chunking=enable_chunking), + ) + + def _schedule_default(self) -> SchedulerOutputs: + """Schedule queued requests. + + The current policy is designed to optimize the throughput. First, + it batches as many prefill requests as possible. And it schedules + decodes. If there's a pressure on GPU memory, decode requests can + be swapped or preempted. + """ + # Include running requests to the budget. + budget = SchedulingBudget( + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_seqs=self.scheduler_config.max_num_seqs, + ) + # Make sure we include num running seqs before scheduling prefill, + # so that we don't schedule beyond max_num_seqs for prefill. + for seq_group in self.running: + budget.add_num_seqs(seq_group.request_id, + seq_group.get_max_num_running_seqs()) + curr_loras = (set( + seq_group.lora_int_id for seq_group in self.running + if seq_group.lora_int_id > 0) if self.lora_enabled else None) + + prefills = SchedulerPrefillOutputs.create_empty() + running_scheduled = SchedulerRunningOutputs.create_empty() + swapped_in = SchedulerSwappedInOutputs.create_empty() + + # If any requests are swapped, prioritized swapped requests. + if not self.swapped: + prefills = self._schedule_prefills(budget, + curr_loras, + enable_chunking=False) + + if len(prefills.seq_groups + ) == 0 and self.scheduler_config.policy == "priority": + self._schedule_priority_preemption(budget) + + # Don't schedule decodes if prefills are scheduled. + # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running + # only contains decode requests, not chunked prefills. + if len(prefills.seq_groups) == 0: + running_scheduled = self._schedule_running(budget, + curr_loras, + enable_chunking=False) + + # If any sequence group is preempted, do not swap in any sequence + # group. because it means there's no slot for new running requests. + if (len(running_scheduled.preempted) + + len(running_scheduled.swapped_out) == 0): + swapped_in = \ + self._schedule_swapped(budget, curr_loras) + + assert (budget.num_batched_tokens + <= self.scheduler_config.max_num_batched_tokens) + assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs + + # Update waiting requests. + self.waiting.extendleft(running_scheduled.preempted) + # Update new running requests. + if len(prefills.seq_groups) > 0: + self.running.extend([s.seq_group for s in prefills.seq_groups]) + + self.running.extend(running_scheduled.decode_seq_groups_list) + + if len(swapped_in.decode_seq_groups) > 0: + self.running.extend( + [s.seq_group for s in swapped_in.decode_seq_groups]) + + # Update swapped requests. + self.swapped.extend(running_scheduled.swapped_out) + preempted = len(running_scheduled.preempted) + len( + running_scheduled.swapped_out) + + # There should be no prefill from running queue because this policy + # doesn't allow chunked prefills. + assert len(running_scheduled.prefill_seq_groups) == 0 + assert len(swapped_in.prefill_seq_groups) == 0 + + # Merge lists + num_prefill_groups = len(prefills.seq_groups) + ignored_seq_groups_for_embeds = list[SequenceGroup]() + if num_prefill_groups > 0: + scheduled_seq_groups = prefills.seq_groups + scheduled_seq_groups.extend(running_scheduled.decode_seq_groups) + ignored_seq_groups_for_embeds.clear() + else: + scheduled_seq_groups = running_scheduled.decode_seq_groups + if len(scheduled_seq_groups) > 0: + using_prompt_embeds = scheduled_seq_groups[ + 0].seq_group.uses_prompt_embeds() + ignored_seq_groups_for_embeds.clear() + indices_ignored = list[int]() + for i, schedule_seq_group in enumerate(scheduled_seq_groups): + if using_prompt_embeds !=\ + schedule_seq_group.seq_group.uses_prompt_embeds(): + ignored_seq_groups_for_embeds.append( + schedule_seq_group.seq_group) + indices_ignored.append(i) + if len(ignored_seq_groups_for_embeds) > 0: + scheduled_seq_groups = [ + group for i, group in enumerate(scheduled_seq_groups) + if i not in indices_ignored + ] + else: + ignored_seq_groups_for_embeds.clear() + + scheduled_seq_groups.extend(swapped_in.decode_seq_groups) + + blocks_to_copy = running_scheduled.blocks_to_copy + blocks_to_copy.extend(swapped_in.blocks_to_copy) + + ignored_seq_groups = prefills.ignored_seq_groups + ignored_seq_groups.extend(ignored_seq_groups_for_embeds) + ignored_seq_groups.extend(swapped_in.infeasible_seq_groups) + + return SchedulerOutputs( + scheduled_seq_groups=scheduled_seq_groups, + num_prefill_groups=num_prefill_groups, + num_batched_tokens=budget.num_batched_tokens + + budget.num_cached_tokens, + blocks_to_swap_in=swapped_in.blocks_to_swap_in, + blocks_to_swap_out=running_scheduled.blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ignored_seq_groups=ignored_seq_groups, + num_lookahead_slots=running_scheduled.num_lookahead_slots, + running_queue_size=len(self.running), + preempted=preempted, + ) + + def _schedule_chunked_prefill(self) -> SchedulerOutputs: + """Schedule queued requests. + + Chunked prefill allows to chunk prefill requests, batch them together + with decode requests. This policy 1. schedule as many decoding requests + as possible. 2. schedule chunked prefill requests that are not + finished. 3. schedule swapped request. 4. schedule new prefill + requests. + + The policy can sustain the high GPU utilization because it can put + prefill and decodes requests to the same batch, while it improves + inter token latency because decodes requests don't need to be blocked + by prefill requests. + """ + budget = SchedulingBudget( + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_seqs=self.scheduler_config.max_num_seqs, + ) + curr_loras: Set[int] = set() + + prefills = SchedulerPrefillOutputs.create_empty() + swapped_in = SchedulerSwappedInOutputs.create_empty() + + # Create partial prefill metadata + partial_prefill_metadata = PartialPrefillMetadata.from_queues( + running=self.running, + waiting=self.waiting, + scheduler_config=self.scheduler_config, + ) + + # Decoding should be always scheduled first by fcfs. + running_scheduled = self._schedule_running( + budget, + curr_loras, + enable_chunking=True, + partial_prefill_metadata=partial_prefill_metadata, + ) + + # Schedule swapped out requests. + # If preemption happens, it means we don't have space for swap-in. + if len(running_scheduled.preempted) + len( + running_scheduled.swapped_out) == 0: + swapped_in = self._schedule_swapped(budget, curr_loras) + + prefills = self._schedule_prefills( + budget, + curr_loras, + enable_chunking=True, + partial_prefill_metadata=partial_prefill_metadata, + ) + + assert (budget.num_batched_tokens + <= self.scheduler_config.max_num_batched_tokens) + assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs + + # Update waiting requests. + self.waiting.extendleft(running_scheduled.preempted) + + # Update new running requests. + # By default, vLLM scheduler prioritizes prefills. + # Once chunked prefill is enabled, + # the policy is changed to prioritize decode requests. + self.running.extend( + [s.seq_group for s in swapped_in.decode_seq_groups]) + self.running.extend( + [s.seq_group for s in swapped_in.prefill_seq_groups]) + self.running.extend( + [s.seq_group for s in running_scheduled.decode_seq_groups]) + # Because multiple prefills may be running concurrently, we need to + # make sure that prefills which are scheduled to finish are listed + # before those that won't. This is so that on the next scheduling + # iteration when they have transitioned to the decode stage, they are + # properly prioritized over sequences that are still in the prefill + # stage. + self.running.extend( + self._order_finishing_prefills_first( + running_scheduled.prefill_seq_groups)) + self.running.extend([s.seq_group for s in prefills.seq_groups]) + + # Update swapped requests. + self.swapped.extend(running_scheduled.swapped_out) + # Put prefills first due to Attention backend ordering assumption. + scheduled_seq_groups = (prefills.seq_groups + + running_scheduled.prefill_seq_groups + + swapped_in.prefill_seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups) + num_prefill_groups = (len(prefills.seq_groups) + + len(swapped_in.prefill_seq_groups) + + len(running_scheduled.prefill_seq_groups)) + # If all prompts, then we set num_lookahead_slots to 0 + # this allows us to go through the `no_spec` path in + # `spec_decode_worker.py` + all_prefills = len(scheduled_seq_groups) == num_prefill_groups + num_lookahead_slots = (0 if + (all_prefills + and not self.scheduler_config.is_multi_step) + else running_scheduled.num_lookahead_slots) + return SchedulerOutputs( + scheduled_seq_groups=scheduled_seq_groups, + num_prefill_groups=num_prefill_groups, + num_batched_tokens=budget.num_batched_tokens + + budget.num_cached_tokens, + blocks_to_swap_in=swapped_in.blocks_to_swap_in, + blocks_to_swap_out=running_scheduled.blocks_to_swap_out, + blocks_to_copy=running_scheduled.blocks_to_copy + + swapped_in.blocks_to_copy, + ignored_seq_groups=prefills.ignored_seq_groups + + swapped_in.infeasible_seq_groups, + num_lookahead_slots=num_lookahead_slots, + running_queue_size=len(self.running), + preempted=(len(running_scheduled.preempted) + + len(running_scheduled.swapped_out)), + ) + + def _order_finishing_prefills_first( + self, scheduled_prefill_seqs: List[ScheduledSequenceGroup] + ) -> List[SequenceGroup]: + """Returns a list of prefilling SequenceGroups where sequences that are + scheduled to finish prefilling are listed first""" + finishing = [ + s.seq_group for s in scheduled_prefill_seqs + if s.seq_group.get_num_uncomputed_tokens() == s.token_chunk_size + ] + not_finishing = [ + s.seq_group for s in scheduled_prefill_seqs + if s.seq_group.get_num_uncomputed_tokens() != s.token_chunk_size + ] + return finishing + not_finishing + + def _schedule(self) -> SchedulerOutputs: + """Schedule queued requests.""" + if self.scheduler_config.chunked_prefill_enabled: + return self._schedule_chunked_prefill() + else: + return self._schedule_default() + + def _can_append_slots(self, seq_group: SequenceGroup, + enable_chunking: bool) -> bool: + """Determine whether or not we have enough space in the KV cache to + continue generation of the sequence group. + """ + # It is True only for testing case to trigger artificial preemption. + if (self.enable_artificial_preemption + and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB + and self.artificial_preempt_cnt > 0): + self.artificial_preempt_cnt -= 1 + return False + + is_prefill = seq_group.is_prefill() + num_lookahead_slots = self._get_num_lookahead_slots( + is_prefill, enable_chunking) + + if is_prefill and num_lookahead_slots > 0: + # Appending prefill slots only happens multi-step and + # chunked-prefill are enabled together. + assert self.scheduler_config.is_multi_step and enable_chunking + + return self.block_manager.can_append_slots( + seq_group=seq_group, num_lookahead_slots=num_lookahead_slots) + + def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: + # async_output_proc is allowed only when we have a single sequence + # in the sequence group + no_single_seq = seq_group.sampling_params is None or ( + seq_group.sampling_params.n == 1) + return no_single_seq + + def schedule( + self + ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: + # Schedule sequence groups. + # This function call changes the internal states of the scheduler + # such as self.running, self.swapped, and self.waiting. + scheduler_start_time = time.perf_counter() + + scheduler_outputs: SchedulerOutputs = self._schedule() + now = time.time() + + if not self.cache_config.enable_prefix_caching: + common_computed_block_nums = [] + + allow_async_output_proc: bool = self.use_async_output_proc + + # Create input data structures. + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + for i, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): + seq_group = scheduled_seq_group.seq_group + token_chunk_size = scheduled_seq_group.token_chunk_size + seq_group.maybe_set_first_scheduled_time(now) + + seq_group_metadata = self._seq_group_metadata_cache[ + self.cache_id].get_object() + seq_group_metadata.seq_data.clear() + seq_group_metadata.block_tables.clear() + + # seq_id -> SequenceData + seq_data: Dict[int, SequenceData] = {} + # seq_id -> physical block numbers + block_tables: Dict[int, List[int]] = {} + + if seq_group.is_encoder_decoder(): + # Encoder associated with SequenceGroup + encoder_seq = seq_group.get_encoder_seq() + assert encoder_seq is not None + encoder_seq_data = encoder_seq.data + # Block table for cross-attention + # Also managed at SequenceGroup level + cross_block_table = self.block_manager.get_cross_block_table( + seq_group) + else: + encoder_seq_data = None + cross_block_table = None + + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq_id = seq.seq_id + seq_data[seq_id] = seq.data + block_tables[seq_id] = self.block_manager.get_block_table(seq) + self.block_manager.access_all_blocks_in_seq(seq, now) + + if self.cache_config.enable_prefix_caching: + common_computed_block_nums = ( + self.block_manager.get_common_computed_block_ids( + seq_group.get_seqs(status=SequenceStatus.RUNNING))) + + do_sample = True + is_prompt = seq_group.is_prefill() + # We should send the metadata to workers when the first prefill + # is sent. Subsequent requests could be chunked prefill or decode. + is_first_prefill = False + if is_prompt: + seqs = seq_group.get_seqs() + # Prefill has only 1 sequence. + assert len(seqs) == 1 + num_computed_tokens = seqs[0].data.get_num_computed_tokens() + is_first_prefill = num_computed_tokens == 0 + # In the next iteration, all prompt tokens are not computed. + # It means the prefill is chunked, and we don't need sampling. + # NOTE: We use get_len instead of get_prompt_len because when + # a sequence is preempted, prefill includes previous generated + # output tokens. + if (token_chunk_size + num_computed_tokens + < seqs[0].data.get_len()): + do_sample = False + + # It assumes the scheduled_seq_groups is ordered by + # prefill < decoding. + if is_first_prefill or not self.scheduler_config.send_delta_data: + seq_group_metadata = SequenceGroupMetadata( + request_id=seq_group.request_id, + is_prompt=is_prompt, + seq_data=seq_data, + sampling_params=seq_group.sampling_params, + block_tables=block_tables, + do_sample=do_sample, + pooling_params=seq_group.pooling_params, + token_chunk_size=token_chunk_size, + lora_request=seq_group.lora_request, + computed_block_nums=common_computed_block_nums, + encoder_seq_data=encoder_seq_data, + cross_block_table=cross_block_table, + state=seq_group.state, + token_type_ids=seq_group.token_type_ids, + # `multi_modal_data` will only be present for the 1st comm + # between engine and worker. + # the subsequent comms can still use delta, but + # `multi_modal_data` will be None. + multi_modal_data=(seq_group.multi_modal_data + if scheduler_outputs.num_prefill_groups + > 0 else None), + multi_modal_placeholders=( + seq_group.multi_modal_placeholders + if scheduler_outputs.num_prefill_groups > 0 else None), + prompt_adapter_request=seq_group.prompt_adapter_request, + ) + else: + # When SPMD mode is enabled, we only send delta data except for + # the first request to reduce serialization cost. + seq_data_delta = {} + for id, data in seq_data.items(): + seq_data_delta[id] = data.get_delta_and_reset() + seq_group_metadata = SequenceGroupMetadataDelta( + seq_data_delta, + seq_group.request_id, + block_tables, + is_prompt, + do_sample=do_sample, + token_chunk_size=token_chunk_size, + computed_block_nums=common_computed_block_nums, + ) + seq_group_metadata_list.append(seq_group_metadata) + + if allow_async_output_proc: + allow_async_output_proc = self._allow_async_output_proc( + seq_group) + + # Now that the batch has been created, we can assume all blocks in the + # batch will have been computed before the next scheduling invocation. + # This is because the engine assumes that a failure in model execution + # will crash the vLLM instance / will not retry. + for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + self.block_manager.mark_blocks_as_computed( + scheduled_seq_group.seq_group, + scheduled_seq_group.token_chunk_size) + + self._seq_group_metadata_cache[self.next_cache_id].reset() + + scheduler_time = time.perf_counter() - scheduler_start_time + # Add this to scheduler time to all the sequences that are currently + # running. This will help estimate if the scheduler is a significant + # component in the e2e latency. + for seq_group in self.running: + if seq_group is not None and seq_group.metrics is not None: + if seq_group.metrics.scheduler_time is not None: + seq_group.metrics.scheduler_time += scheduler_time + else: + seq_group.metrics.scheduler_time = scheduler_time + + # Move to next cache (if exists) + self.cache_id = self.next_cache_id + + # Return results + return (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) + + def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: + self.block_manager.fork(parent_seq, child_seq) + + def free_seq(self, seq: Sequence) -> None: + """Free a sequence from a block table.""" + self.block_manager.free(seq) + + def remove_seq_from_computed_blocks_tracker( + self, seq_group: SequenceGroup, + status: Optional[SequenceStatus]) -> None: + seqs = seq_group.get_seqs(status=status) + for seq in seqs: + self._remove_seq_from_computed_blocks_tracker(seq) + + def _remove_seq_from_computed_blocks_tracker(self, seq: Sequence) -> None: + """ + Free a sequence computed blocks tracker _seq_id_to_blocks_hashes + and _seq_id_to_num_tokens_computed. + """ + self.block_manager.remove_seq_from_computed_blocks_tracker(seq) + + def _free_finished_seqs(self, seq_group: SequenceGroup) -> None: + """Free finished seqs in a sequence group.""" + for seq in seq_group.get_seqs(): + if seq.is_finished(): + self.free_seq(seq) + + def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None: + if seq_group.is_finished(): + # Free cross-attention block table, if it exists + self._free_seq_group_cross_attn_blocks(seq_group) + + # Add the finished requests to the finished requests list. + # This list will be used to update the Mamba cache in the + # next step. + self._finished_requests_ids.append(seq_group.request_id) + + # Free finished seqs + self._free_finished_seqs(seq_group) + + def free_finished_seq_groups(self) -> None: + remaining: Deque[SequenceGroup] = deque() + for seq_group in self.running: + self._free_finished_seq_group(seq_group) + if not seq_group.is_finished(): + remaining.append(seq_group) + + self.running = remaining + + # Handle async stopped sequence groups + # (ones that reached max model len) + if self._async_stopped: + for seq_group in self._async_stopped: + self._free_seq_group_cross_attn_blocks(seq_group) + self._finished_requests_ids.append(seq_group.request_id) + + # Free finished seqs + self._free_finished_seqs(seq_group) + + self._async_stopped.clear() + + def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: + self.block_manager.allocate(seq_group) + for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): + seq.status = SequenceStatus.RUNNING + + def _append_slots( + self, + seq_group: SequenceGroup, + blocks_to_copy: List[Tuple[int, int]], + enable_chunking: bool = False, + ) -> None: + """Appends new slots to the sequences in the given sequence group. + + Args: + seq_group (SequenceGroup): The sequence group containing the + sequences to append slots to. + blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two + ints, the first int is the source block index, and the second + int is the destination block index. This list is updated with + the new source and destination block indices for the appended + slots. + enable_chunking (bool): True if chunked prefill is enabled. + """ + is_prefill: bool = seq_group.is_prefill() + num_lookahead_slots: int = self._get_num_lookahead_slots( + is_prefill, enable_chunking) + + seq_group.init_multi_step_from_lookahead_slots( + num_lookahead_slots, + num_scheduler_steps=self.scheduler_config.num_scheduler_steps, + is_multi_step=self.scheduler_config.is_multi_step, + enable_chunking=enable_chunking, + ) + + seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING + if self.scheduler_config.is_multi_step and enable_chunking: + # In multi-step chunked-prefill any sequence type can have + # slots appended. + seq_status = None + + for seq in seq_group.get_seqs(status=seq_status): + cows = self.block_manager.append_slots(seq, num_lookahead_slots) + if len(cows) > 0: + blocks_to_copy.extend(cows) + + def _preempt(self, seq_group: SequenceGroup, + blocks_to_swap_out: List[Tuple[int, int]]) -> PreemptionMode: + # If preemption mode is not specified, we determine the mode as follows: + # We use recomputation by default since it incurs lower overhead than + # swapping. However, when the sequence group has multiple sequences + # (e.g., beam search), recomputation is not currently supported. In + # such a case, we use swapping instead. + # FIXME(woosuk): This makes our scheduling policy a bit bizarre. + # As swapped sequences are prioritized over waiting sequences, + # sequence groups with multiple sequences are implicitly prioritized + # over sequence groups with a single sequence. + # TODO(woosuk): Support recomputation for sequence groups with multiple + # sequences. This may require a more sophisticated CUDA kernel. + if self.user_specified_preemption_mode is None: + if seq_group.get_max_num_running_seqs() == 1: + preemption_mode = PreemptionMode.RECOMPUTE + else: + preemption_mode = PreemptionMode.SWAP + + elif self.user_specified_preemption_mode == "swap": + preemption_mode = PreemptionMode.SWAP + else: + preemption_mode = PreemptionMode.RECOMPUTE + + if self.num_cumulative_preemption % 50 == 0: + logger.warning( + "Sequence group %s is preempted by %s mode because there is " + "not enough KV cache space. This can affect the end-to-end " + "performance. Increase gpu_memory_utilization or " + "tensor_parallel_size to provide more KV cache memory. " + "total_num_cumulative_preemption=%d", + seq_group.request_id, + preemption_mode, + self.num_cumulative_preemption + 1, + ) + self.num_cumulative_preemption += 1 + + if preemption_mode == PreemptionMode.RECOMPUTE: + self._preempt_by_recompute(seq_group) + elif preemption_mode == PreemptionMode.SWAP: + self._preempt_by_swap(seq_group, blocks_to_swap_out) + else: + raise AssertionError("Invalid preemption mode.") + return preemption_mode + + def _preempt_by_recompute( + self, + seq_group: SequenceGroup, + ) -> None: + seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + assert len(seqs) == 1 + for seq in seqs: + seq.status = SequenceStatus.WAITING + self.free_seq(seq) + seq.reset_state_for_recompute() + self._free_seq_group_cross_attn_blocks(seq_group) + + def _preempt_by_swap( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: List[Tuple[int, int]], + ) -> None: + self._swap_out(seq_group, blocks_to_swap_out) + + def _swap_in( + self, + seq_group: SequenceGroup, + blocks_to_swap_in: List[Tuple[int, int]], + ) -> None: + mapping = self.block_manager.swap_in(seq_group) + blocks_to_swap_in.extend(mapping) + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + seq.status = SequenceStatus.RUNNING + + def _swap_out( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: List[Tuple[int, int]], + ) -> None: + if not self.block_manager.can_swap_out(seq_group): + # FIXME(woosuk): Abort the sequence group instead of aborting the + # entire engine. + raise RuntimeError( + "Aborted due to the lack of CPU swap space. Please increase " + "the swap space to avoid this error.") + mapping = self.block_manager.swap_out(seq_group) + blocks_to_swap_out.extend(mapping) + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq.status = SequenceStatus.SWAPPED + + def _passed_delay(self, now: float) -> bool: + if self.prev_prompt: + self.last_prompt_latency = now - self.prev_time + self.prev_time, self.prev_prompt = now, False + # Delay scheduling prompts to let waiting queue fill up + if self.scheduler_config.delay_factor > 0 and self.waiting: + earliest_arrival_time = min( + [e.metrics.arrival_time for e in self.waiting]) + passed_delay = ((now - earliest_arrival_time) + > (self.scheduler_config.delay_factor * + self.last_prompt_latency) or not self.running) + else: + passed_delay = True + return passed_delay + + def _get_num_lookahead_slots(self, is_prefill: bool, + enable_chunking: bool) -> int: + """The number of slots to allocate per sequence per step, beyond known + token ids. Speculative decoding uses these slots to store KV activations + of tokens which may or may not be accepted. + + Speculative decoding does not yet support prefill, so we do not perform + lookahead allocation for prefill. + + When chunking is enabled with multi-step, we allocate lookahead slots + for the prefills for when the prefills turn into decodes in the first + step. + """ + if is_prefill: + if self.scheduler_config.is_multi_step and enable_chunking: + # num_lookahead_slots was introduced in the context of decodes, + # in Speculative Decoding. + # When the num_scheduler_steps is 8, say, then the + # num_lookahead_slots is 7. Meaning, we are doing a 1-step of + # decode anyways and we wish to do 7 more. + # + # "lookaheads" for prefills, is introduced in support for + # Chunked-Prefill in Multi-Step. + return self.scheduler_config.num_lookahead_slots + 1 + else: + return 0 + + return self.scheduler_config.num_lookahead_slots + + def _get_num_new_uncached_and_cached_tokens( + self, + seq_group: SequenceGroup, + status: SequenceStatus, + enable_chunking: bool, + budget: SchedulingBudget, + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, + ) -> Tuple[int, int]: + """ + Returns the number of new uncached and cached tokens to schedule for a + given sequence group that's in a given `status`. + + The API could chunk the number of tokens to compute based on `budget` + if `enable_chunking` is True. If a sequence group has multiple + sequences (e.g., running beam search), it means it is in decoding + phase, so chunking doesn't happen. + + Returns (0, 0) if the new token cannot be computed due to token budget. + + The cached tokens's blocks are already computed, and the attention + backend will reuse the cached blocks rather than recomputing them. So + the scheduler could schedule these cached tokens "for free". + + Args: + seq_group: The sequence group to get the number of new tokens to + schedule. + status: The status of the sequences to get the number of new tokens + to schedule. + enable_chunking: Whether to chunk the number of tokens to compute. + budget: The budget to chunk the number of tokens to compute. + partial_prefill_metadata: information about the partial prefills + that are currently running + + + Returns: + A tuple of two ints. The first int is the number of new uncached + tokens to schedule. The second int is the number of cached tokens. + If no more new tokens can be scheduled, returns (0, 0). + """ + num_cached_new_tokens = 0 + num_uncached_new_tokens = 0 + + seqs = seq_group.get_seqs(status=status) + # Compute the number of new uncached and cached tokens for + # each sequence. + for seq in seqs: + if not seq.is_prefill(): + # Decode sequences should always just have 1 uncached token + # TODO(rickyx): Actually is this still correct for multi-step? + num_uncached_new_tokens += 1 + continue + + num_computed_tokens_seq = seq.get_num_computed_tokens() + all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq + if not self.cache_config.enable_prefix_caching: + # If prefix caching is not enabled, all new tokens are uncached. + num_uncached_new_tokens += all_num_new_tokens_seq + continue + + # NOTE: the cache token might be currently in a block that's in an + # evictor meaning that it's not yet allocated. However, we don't + # exclude such tokens in the cache count because it will be + # guaranteed to be allocated later if the sequence can be allocated. + num_cached_tokens_seq = self.block_manager.get_num_cached_tokens( + seq) + + # Sanity check. + if num_cached_tokens_seq < num_computed_tokens_seq: + # This should only happen with chunked prefill, and + # the seq is still in prefill. The `num_cached_tokens_seq` + # is the value we calculated on scheduling the first prefill. + # For subsequent continuous prefill steps, we cached the + # number of cache tokens for the sequence so the cached token + # count could be less than the number of computed tokens. + # See comments on `ComputedBlocksTracker` for more details. + assert ( + seq.is_prefill() and seq.status == SequenceStatus.RUNNING + and self.scheduler_config.chunked_prefill_enabled + ), ("Number of cached tokens should not be less than the " + "number of computed tokens for a sequence that's still " + f"in prefill. But there are {num_cached_tokens_seq} cached " + f"tokens and {num_computed_tokens_seq} computed tokens " + f"for sequence {seq.seq_id}.") + + num_cached_new_tokens_seq = max( + 0, num_cached_tokens_seq - num_computed_tokens_seq) + num_uncached_new_tokens_seq = (all_num_new_tokens_seq - + num_cached_new_tokens_seq) + + num_uncached_new_tokens += num_uncached_new_tokens_seq + num_cached_new_tokens += num_cached_new_tokens_seq + + if num_uncached_new_tokens == 0 and num_cached_new_tokens > 0: + # For a fully cached hit sequence, we actually need to recompute the + # last token. So we need at least 1 uncached token to schedule. + # See ModelRunner._compute_for_prefix_cache_hit for more details. + num_uncached_new_tokens = 1 + num_cached_new_tokens -= 1 + + if enable_chunking and len(seqs) == 1: + # Chunk if a running request cannot fit in the given budget. + # If number of seq > 1, it means it is doing beam search + # in a decode phase. Do not chunk. + num_uncached_new_tokens = self._chunk_new_tokens_to_schedule( + self.scheduler_config, + self.cache_config, + budget, + self._get_prompt_limit(seq_group), + num_uncached_new_tokens, + self.partial_prefill_budget_lookup_list, + partial_prefill_metadata, + ) + + return num_uncached_new_tokens, num_cached_new_tokens + + @staticmethod + def _chunk_new_tokens_to_schedule( + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + budget: SchedulingBudget, + prompt_limit: int, + num_new_tokens: int, + partial_prefill_budget_lookup_list: List[int], + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, + ) -> int: + """ + Chunks the number of new tokens to schedule based on the budget when + chunked prefill is enabled. + + Args: + scheduler_config: The scheduler config. + cache_config: The cache config. + budget: The budget to chunk the number of tokens to compute. + prompt_limit: The maximum number of tokens allowed in a prompt. + num_new_tokens: The number of new tokens to schedule. + + Returns: + The number of new tokens to schedule after chunking. + """ + remaining_token_budget = budget.remaining_token_budget() + if scheduler_config.is_multi_step: + # The current multi-step + chunked prefill capability does + # not actually support chunking prompts. + # + # Therefore, `num_new_tokens` is computed in the same fashion + # for both multi-step+chunked-prefill & + # multi-step+chunked-prefill+APC + # + # Prompts with more tokens than the current remaining budget + # are postponed to future scheduler steps + if num_new_tokens > prompt_limit: + # If the seq_group is in prompt-stage, pass the + # num_new_tokens as-is so the caller can ignore + # the sequence. + return num_new_tokens + + return 0 if num_new_tokens > \ + remaining_token_budget else num_new_tokens + + # Get the number of tokens to allocate to this prefill slot + prefill_slot_budget = ( + remaining_token_budget if partial_prefill_metadata is None else + partial_prefill_budget_lookup_list[ + partial_prefill_metadata.schedulable_prefills]) + + if cache_config.enable_prefix_caching: + # When prefix caching is enabled and we're partially prefilling + # a sequence, we always allocate a number of new tokens that is + # divisible by the block size to avoid partial block matching. + block_size = cache_config.block_size + # Don't exceed either the total budget or slot budget. + # Take min of those and get the next lowest multiple of the + # block size: + remaining_token_budget = ( + min(remaining_token_budget, prefill_slot_budget) // + block_size) * block_size + # NB: In the case where num_new_tokens < budget, we are + # finishing prefill for this sequence, so we do not need to + # allocate a full block. + + num_new_tokens = min(num_new_tokens, remaining_token_budget, + prefill_slot_budget) + + return num_new_tokens diff --git a/vllm/device_allocator/__init__.py b/vllm/device_allocator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py new file mode 100644 index 0000000..942e866 --- /dev/null +++ b/vllm/device_allocator/cumem.py @@ -0,0 +1,281 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# cumem-based pytorch pluggable allocator to implement sleep mode. +# other approaches tried but failed: +# - cuda-python package binding +# - custom libcuda driver ctypes wrapper +# both of them failed because of cuda context mismatch. +# not sure why, they are created from a different context. +# the only successful approach is to call cuda driver API in C. +import dataclasses +import gc +import os +from contextlib import contextmanager +from typing import Any, Callable, Optional, Union + +import torch + +from vllm.utils import is_pin_memory_available + + +def find_loaded_library(lib_name) -> Optional[str]: + """ + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, + the file `/proc/self/maps` contains the memory maps of the process, which includes the + shared libraries loaded by the process. We can use this file to find the path of the + a loaded library. + """ # noqa + found_line = None + with open("/proc/self/maps") as f: + for line in f: + if lib_name in line: + found_line = line + break + if found_line is None: + # the library is not loaded in the current process + return None + # if lib_name is libcudart, we need to match a line with: + # address /path/to/libcudart-hash.so.11.0 + start = found_line.index("/") + path = found_line[start:].strip() + filename = path.split("/")[-1] + assert filename.rpartition(".so")[0].startswith(lib_name), \ + f"Unexpected filename: {filename} for library {lib_name}" + return path + + +cumem_available = False +try: + from vllm.cumem_allocator import (init_module, python_create_and_map, + python_unmap_and_release) + from vllm.distributed.device_communicators.cuda_wrapper import ( + CudaRTLibrary) + lib_name = find_loaded_library("cumem_allocator") + libcudart = CudaRTLibrary() + cumem_available = True +except ModuleNotFoundError: + # rocm platform does not support cumem allocator + init_module = None + python_create_and_map = None + python_unmap_and_release = None + CudaRTLibrary = None + lib_name = None + libcudart = None + +# py_device, py_alignedSize, py_d_mem, py_p_memHandle +HandleType = tuple[int, int, int, int] + + +@dataclasses.dataclass +class AllocationData: + handle: HandleType + tag: str + cpu_backup_tensor: Optional[torch.Tensor] = None + + +def create_and_map(allocation_handle: HandleType) -> None: + python_create_and_map(*allocation_handle) + + +def unmap_and_release(allocation_handle: HandleType) -> None: + python_unmap_and_release(*allocation_handle) + + +def get_pluggable_allocator( + python_malloc_fn: Callable[[int], + int], python_free_func: Callable[[int, int], + None] +) -> torch.cuda.memory.CUDAPluggableAllocator: + init_module(python_malloc_fn, python_free_func) + new_alloc = torch.cuda.memory.CUDAPluggableAllocator( + lib_name, 'my_malloc', 'my_free') + return new_alloc + + +@contextmanager +def use_memory_pool_with_allocator( + python_malloc_fn: Callable[[int], int], + python_free_func: Callable[[int, int], None]) -> None: + new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func) + mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator) + with torch.cuda.memory.use_mem_pool(mem_pool): + yield mem_pool, new_alloc + + +class CuMemAllocator: + """ + A singleton class that manages a memory pool for CUDA tensors. + The memory in this pool can be offloaded or discarded when the + allocator sleeps. + + Inside the `use_memory_pool(tag)` context, all tensors created will + be allocated in the memory pool, and has the same tag as the + tag passed to the context. + + When we call `sleep`, all tensors with the specified tag will be + offloaded to CPU memory, and the rest of the tensors will be discarded. + When we call `wake_up`, all tensors that are previously offloaded + will be loaded back to GPU memory, and the rest of the tensors will + have empty memory. + + Why it needs to be a singleton? + When allocated tensors are garbage collected, PyTorch will call + the free callback, which will call the `python_free_callback` method. + The C-extension uses a global variable to store the function of an + instance of this class. If we create multiple instances of this class, + the global variable will be overwritten and the free callback will + not work as expected. + """ + instance: "CuMemAllocator" = None + default_tag: str = "default" + + @staticmethod + def get_instance() -> "CuMemAllocator": + """ + CuMemAllocator is a singleton class. + We cannot call the constructor directly. + Call this method to get the instance. + """ + assert cumem_available, "cumem allocator is not available" + if CuMemAllocator.instance is None: + CuMemAllocator.instance = CuMemAllocator() + return CuMemAllocator.instance + + def __init__(self): + conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") + assert "expandable_segments:True" not in conf, \ + ("Expandable segments are not compatible with memory pool. " + "Please track https://github.com/pytorch/pytorch/issues/147851 " + "for the latest updates.") + + self.pointer_to_data: dict[int, AllocationData] = {} + self.current_tag: str = CuMemAllocator.default_tag + self.allocator_and_pools: dict[str, Any] = {} + + def python_malloc_callback(self, allocation_handle: HandleType) -> None: + """ + Internal method to store the allocation data + when memory is allocated in the memory pool.""" + py_d_mem = allocation_handle[2] + self.pointer_to_data[py_d_mem] = AllocationData( + allocation_handle, self.current_tag) + return + + def python_free_callback(self, ptr: int) -> HandleType: + """ + Internal method to look up the allocation data + when memory is freed in the memory pool.""" + data = self.pointer_to_data.pop(ptr) + if data.cpu_backup_tensor is not None: + data.cpu_backup_tensor = None + return data.handle + + def sleep( + self, + offload_tags: Optional[Union[tuple[str, ...], + str]] = None) -> None: + """ + Put the allocator in sleep mode. + All data in the memory allocation with the specified tag will be + offloaded to CPU memory, and others will be discarded. + + :param offload_tags: The tags of the memory allocation that will be + offloaded. The rest of the memory allocation will be discarded. + """ + if offload_tags is None: + # by default, allocated tensors are offloaded + # when the allocator sleeps + offload_tags = (CuMemAllocator.default_tag, ) + elif isinstance(offload_tags, str): + offload_tags = (offload_tags, ) + + assert isinstance(offload_tags, tuple) + + for ptr, data in self.pointer_to_data.items(): + handle = data.handle + if data.tag in offload_tags: + size_in_bytes = handle[1] + cpu_backup_tensor = torch.empty( + size_in_bytes, + dtype=torch.uint8, + device='cpu', + pin_memory=is_pin_memory_available()) + cpu_ptr = cpu_backup_tensor.data_ptr() + libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes) + data.cpu_backup_tensor = cpu_backup_tensor + unmap_and_release(handle) + + gc.collect() + torch.cuda.empty_cache() + + def wake_up(self, tags: Optional[list[str]] = None) -> None: + """ + Wake up the allocator from sleep mode. + All data that is previously offloaded will be loaded back to GPU + memory, and the rest of the data will have empty memory. + + :param tags: The tags of the memory allocation that will be loaded + back to GPU memory. If None, all memory allocation will be loaded + back to GPU memory. + """ + for ptr, data in self.pointer_to_data.items(): + if tags is None or data.tag in tags: + handle = data.handle + create_and_map(handle) + if data.cpu_backup_tensor is not None: + cpu_backup_tensor = data.cpu_backup_tensor + if cpu_backup_tensor is not None: + size_in_bytes = cpu_backup_tensor.numel( + ) * cpu_backup_tensor.element_size() + cpu_ptr = cpu_backup_tensor.data_ptr() + libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes) + data.cpu_backup_tensor = None + + @contextmanager + def use_memory_pool(self, tag: Optional[str] = None): + """ + A context manager to use the memory pool. + All memory allocation created inside the context will be allocated + in the memory pool, and has the specified tag. + + :param tag: The tag of the memory allocation. If None, the default tag + will be used. + """ + if tag is None: + tag = CuMemAllocator.default_tag + + assert isinstance(tag, str) + + old_tag = self.current_tag + self.current_tag = tag + with use_memory_pool_with_allocator(self.python_malloc_callback, + self.python_free_callback) as data: + # start to hit another PyTorch bug in PyTorch 2.6, + # possibly because of gc-related issue w.r.t. the allocator and + # the memory pool. + # to avoid the issue, we keep a reference of the data. + # see https://github.com/pytorch/pytorch/issues/146431 . + self.allocator_and_pools[tag] = data + yield + # PyTorch's bug, calling torch.cuda.empty_cache() will error + # when using pluggable allocator, see + # https://github.com/pytorch/pytorch/issues/145168 . + # if we have some memory allocated and then freed, + # the memory will not be released. + # right now it is fine, because we only use this allocator + # during weight loading and kv cache creation, where we only + # allocate memory. + # TODO: we need to find a way to release the memory, + # i.e. calling torch.cuda.empty_cache() + self.current_tag = old_tag + + def get_current_usage(self) -> int: + """ + Get the total number of bytes allocated in the memory pool. + """ + sum_bytes: int = 0 + for ptr, data in self.pointer_to_data.items(): + handle = data.handle + sum_bytes += handle[1] + return sum_bytes diff --git a/vllm/distributed/__init__.py b/vllm/distributed/__init__.py new file mode 100644 index 0000000..e911b2a --- /dev/null +++ b/vllm/distributed/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .communication_op import * +from .parallel_state import * +from .utils import * diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py new file mode 100644 index 0000000..0a5a951 --- /dev/null +++ b/vllm/distributed/communication_op.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Optional, Union + +import torch +import torch.distributed + +from .parallel_state import get_tp_group + + +def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: + """All-reduce the input tensor across model parallel group.""" + return get_tp_group().all_reduce(input_) + + +def tensor_model_parallel_all_gather(input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + """All-gather the input tensor across model parallel group.""" + return get_tp_group().all_gather(input_, dim) + + +def tensor_model_parallel_reduce_scatter(input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + """Reduce-Scatter the input tensor across model parallel group.""" + return get_tp_group().reduce_scatter(input_, dim) + + +def tensor_model_parallel_gather(input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + """Gather the input tensor across model parallel group.""" + return get_tp_group().gather(input_, dst, dim) + + +def broadcast_tensor_dict(tensor_dict: Optional[dict[Any, Union[torch.Tensor, + Any]]] = None, + src: int = 0): + if not torch.distributed.is_initialized(): + return tensor_dict + return get_tp_group().broadcast_tensor_dict(tensor_dict, src) diff --git a/vllm/distributed/device_communicators/__init__.py b/vllm/distributed/device_communicators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py new file mode 100644 index 0000000..85f87cb --- /dev/null +++ b/vllm/distributed/device_communicators/all2all.py @@ -0,0 +1,264 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING, Any + +import torch +import torch.distributed as dist + +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.utils import has_deep_ep, has_pplx + +from .base_device_communicator import All2AllManagerBase, Cache + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.layer import FusedMoE +else: + FusedMoE = None + + +class NaiveAll2AllManager(All2AllManagerBase): + """ + A naive implementation of all2all communication. + It uses all-reduce under the hood, which is not + efficient at all. The main purpose is for testing and + debugging. + """ + + def __init__(self, cpu_group): + super().__init__(cpu_group) + + def naive_multicast(self, x: torch.Tensor, + cu_tokens_across_dp_cpu: torch.Tensor): + assert (len(x.shape) == 2) + buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), + device=x.device, + dtype=x.dtype) + + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.dp_rank - 1] + end = cu_tokens_across_dp_cpu[self.dp_rank] + buffer[start:end, :].copy_(x) + for idx in range(self.dp_world_size): + start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] + end = cu_tokens_across_dp_cpu[idx] + self.dp_group.broadcast(buffer[start:end, :], idx) + + return buffer + + def dispatch(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + + hidden_states = self.naive_multicast(hidden_states, + cu_tokens_across_dp_cpu) + router_logits = self.naive_multicast(router_logits, + cu_tokens_across_dp_cpu) + return hidden_states, router_logits + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.dp_rank - 1] + end = cu_tokens_across_dp_cpu[self.dp_rank] + + all_hidden_states = self.dp_group.all_reduce(hidden_states) + hidden_states = all_hidden_states[start:end, :] + return hidden_states + + def destroy(self): + pass + + +class PPLXAll2AllManager(All2AllManagerBase): + """ + All2All communication based on PPLX kernels. + """ + + def __init__(self, cpu_group): + assert has_pplx( + ), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa + super().__init__(cpu_group) + + if self.internode: + # inter-node communication needs nvshmem, + # intra-node communication uses p2p mapping directly + from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_get_unique_id, + nvshmem_init) + logger.debug( + "Initialize NVSHMEM for pplx_kernels: " + "rank=%d, world size=%d", self.rank, self.world_size) + uid = nvshmem_get_unique_id( + ) if self.rank == 0 else nvshmem_alloc_empty_unique_id() + dist.broadcast(uid, + src=dist.get_process_group_ranks(self.cpu_group)[0], + group=self.cpu_group) + logger.debug("PPLX NVSHMEM UID = %s", uid) + nvshmem_init(uid, self.rank, self.world_size) + + self.handle_cache = Cache() + + def get_handle(self, kwargs): + import pplx_kernels as pplx + return self.handle_cache.get_or_create( + kwargs, pplx.AllToAll.internode + if self.internode else pplx.AllToAll.intranode) + + def dispatch(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + raise NotImplementedError + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def destroy(self): + with self.handle_cache._lock: + for _, handle in self.handle_cache._cache.items(): + handle.destroy() + + if self.internode: + from pplx_kernels.nvshmem import nvshmem_finalize + logger.debug("PPLX NVSHMEM finalize") + nvshmem_finalize() + + +class DeepEPAll2AllManagerBase(All2AllManagerBase): + """ + All2All communication based on DeepEP High-Throughput kernels. + """ + + def __init__(self, cpu_group): + assert has_deep_ep( + ), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa + super().__init__(cpu_group) + self.handle_cache = Cache() + + # This is the DeepEP default. Stick to it till we can establish + # reasonable defaults based on profiling. + self.num_sms = 20 + + def get_handle(self, kwargs): + raise NotImplementedError + + def dispatch(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + raise NotImplementedError + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def destroy(self): + pass + + +class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): + """ + All2All communication based on DeepEP High-Throughput kernels. + """ + + def __init__(self, cpu_group): + super().__init__(cpu_group) + + def _make_all2all_kwargs(self) -> dict[Any, Any]: + # Defaults for internode and intranode are taken from DeepEP tests. + num_nvl_bytes = 1024 * 1024 * 1024 + num_rdma_bytes = None + num_qps_per_rank = None + + if self.internode: + num_rdma_bytes = 1024 * 1024 * 1024 + num_qps_per_rank = self.num_sms // 2 + else: + num_rdma_bytes = 0 + num_qps_per_rank = 1 + + assert num_rdma_bytes is not None + assert num_qps_per_rank is not None + return dict(group=self.cpu_group, + num_nvl_bytes=num_nvl_bytes, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=False, + num_qps_per_rank=num_qps_per_rank) + + def get_handle(self, kwargs): + + assert len(kwargs) == 0, ( + "DeepEPHTAll2AllManager expects no arguments. All the required " + "args are computed in the Manager itself.") + + import deep_ep + buffer_kwargs = self._make_all2all_kwargs() + logger.debug("DeepEP all2all args %s", buffer_kwargs) + handle: deep_ep.Buffer = self.handle_cache.get_or_create( + buffer_kwargs, deep_ep.Buffer) + # It is dangerous to set num sms outside this function. num_sms is not + # a part of the hash-key that identifies this object. If we are in a + # situation where we make objects with different num_sms, the hash key + # in get_or_create must be updated. + handle.set_num_sms(self.num_sms) + return handle + + +class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): + """ + All2All communication based on DeepEP Low-Latency kernels. + """ + + def __init__(self, cpu_group): + super().__init__(cpu_group) + + def _make_all2all_kwargs( + self, + max_num_tokens_per_dp_rank: int, + token_hidden_size: int, + num_ep_ranks: int, + num_global_experts: int, + num_local_experts: int, + ) -> dict[Any, Any]: + """ + max_num_tokens_per_dp_rank : the maximum number of tokens a DP rank + can dispatch all the ranks must hold the same value. + token_hidden_size: the hidden dimension of each token. + num_ep_ranks: the number of EP group ranks. + num_global_experts: Number of experts in the model. + num_local_experts: Number of experts in an EP rank. + """ + import deep_ep + + # Defaults for internode and intranode are taken from DeepEP tests. + num_nvl_bytes = 1024 * 1024 * 1024 + num_qps_per_rank = num_local_experts + num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( + num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank, + hidden=token_hidden_size, + num_ranks=num_ep_ranks, + num_experts=num_global_experts) + + assert num_rdma_bytes is not None + return dict(group=self.cpu_group, + num_nvl_bytes=num_nvl_bytes, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=num_qps_per_rank) + + def get_handle(self, kwargs): + """ + The kwargs for DeepEPLLAll2AllManager is dictated by + _make_all2all_kwargs. + """ + import deep_ep + buffer_kwargs = self._make_all2all_kwargs(**kwargs) + logger.debug("DeepEP all2all args %s", buffer_kwargs) + handle: deep_ep.Buffer = self.handle_cache.get_or_create( + buffer_kwargs, deep_ep.Buffer) + # It is dangerous to set num sms outside this function. num_sms is not + # a part of the hash-key that identifies this object. If we are in a + # situation where we make objects with different num_sms, the hash key + # in get_or_create must be updated. + handle.set_num_sms(self.num_sms) + return handle diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py new file mode 100644 index 0000000..1bc2d8e --- /dev/null +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -0,0 +1,260 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import threading +from typing import Optional +from weakref import WeakValueDictionary + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + + +class Cache: + + def __init__(self): + self._cache: WeakValueDictionary = WeakValueDictionary() + self._lock = threading.RLock() # Reentrant lock for thread safety + + def get_or_create(self, kwargs, func): + # Create a hashable key from the kwargs + key = tuple(sorted((k, v) for k, v in kwargs.items())) + + with self._lock: + instance = self._cache.get(key) + if instance is None: + instance = func(**kwargs) + self._cache[key] = instance + return instance + + +class All2AllManagerBase: + + def __init__(self, cpu_group): + self.cpu_group = cpu_group + + # compute some common properties + from vllm.distributed.parallel_state import (get_dp_group, + get_tp_group, + in_the_same_node_as) + + # all2all lives in ep group, which is merged from dp and tp group + self.dp_group = get_dp_group() + self.tp_group = get_tp_group() + # no self.ep_group since self.ep_group is still in construction + # when we create this object + self.dp_rank = self.dp_group.rank_in_group + self.dp_world_size = self.dp_group.world_size + self.rank = dist.get_rank(cpu_group) + self.world_size = dist.get_world_size(cpu_group) + + # all2all communication often has separate implementations for + # intra-node and inter-node communication + self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0)) + + def get_handle(self, kwargs): + # get a handle for the all2all communication, + # based on the kwargs. + # different layers can have different configs, + # e.g. one layer has hidden size 1024, another has 2048. + # usually the underlying implementation caches the handle + # and reuse it for the same config. + raise NotImplementedError + + def dispatch(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + raise NotImplementedError + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def destroy(self): + pass + + +class DeviceCommunicatorBase: + """ + Base class for device-specific communicator. + It can use the `cpu_group` to initialize the communicator. + If the device has PyTorch integration (PyTorch can recognize its + communication backend), the `device_group` will also be given. + """ + + def __init__(self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = ""): + self.device = device or torch.device("cpu") + self.cpu_group = cpu_group + self.device_group = device_group + self.unique_name = unique_name + self.rank = dist.get_rank(cpu_group) + self.world_size = dist.get_world_size(cpu_group) + self.ranks = dist.get_process_group_ranks(cpu_group) + self.global_rank = dist.get_rank() + self.global_world_size = dist.get_world_size() + self.rank_in_group = dist.get_group_rank(self.cpu_group, + self.global_rank) + + use_ep = False + from vllm.config import get_current_vllm_config + config = get_current_vllm_config() + if config is not None: + # as long as we use data parallel (coupled data parallel + # where all data parallel ranks execute forward together), + # we initialize the all2all manager used in expert parallel. + use_ep = config.parallel_config.data_parallel_size > 1 + + self.use_all2all = "ep" in unique_name and use_ep + self.all2all_manager: Optional[All2AllManagerBase] = None + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + dist.all_reduce(input_, group=self.device_group) + return input_ + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # NOTE: we have to use concat-style all-gather here, + # stack-style all-gather has compatibility issues with + # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 + output_size = (input_size[0] * self.world_size, ) + input_size[1:] + # Allocate output tensor. + output_tensor = torch.empty(output_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. + dist.all_gather_into_tensor(output_tensor, + input_, + group=self.device_group) + # Reshape + output_tensor = output_tensor.reshape((self.world_size, ) + input_size) + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (self.world_size * + input_size[dim], ) + + input_size[dim + 1:]) + return output_tensor + + def reduce_scatter(self, + input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Note: This will produce an incorrect answer if we don't make + # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? + input_tensor = input_.movedim(0, dim).contiguous() + + assert input_tensor.shape[0] % world_size == 0 + chunk_size = input_tensor.shape[0] // world_size + output_shape = (chunk_size, ) + input_tensor.shape[1:] + + output_tensor = torch.empty(output_shape, + dtype=input_tensor.dtype, + device=input_tensor.device) + + # Perform reduce-scatter operation + torch.distributed.reduce_scatter_tensor(output_tensor, + input_tensor, + group=self.device_group) + + # Reshape before returning + return output_tensor.movedim(0, dim).contiguous() + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather(input_, + gather_list, + dst=self.ranks[dst], + group=self.device_group) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + torch.distributed.send(tensor, self.ranks[dst], self.device_group) + + def recv(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None) -> torch.Tensor: + """Receives a tensor from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + if src is None: + src = (self.rank_in_group - 1) % self.world_size + + tensor = torch.empty(size, dtype=dtype, device=self.device) + torch.distributed.recv(tensor, self.ranks[src], self.device_group) + return tensor + + def destroy(self): + pass + + def prepare_communication_buffer_for_model(self, + model: torch.nn.Module) -> None: + """ + Prepare the communication buffer for the model. + """ + if not self.use_all2all: + return + + moe_modules = [ + module for module in model.modules() + if module.__class__.__name__ == "FusedMoE" + ] + for module in moe_modules: + module.quant_method.init_prepare_finalize(module.moe_config, + module.quant_config) + + def dispatch( + self, hidden_states: torch.Tensor, + router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Dispatch the hidden states and router logits to the appropriate device. + This is a no-op in the base class. + """ + return hidden_states, router_logits + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Combine the hidden states and router logits from the appropriate device. + This is a no-op in the base class. + """ + return hidden_states diff --git a/vllm/distributed/device_communicators/cpu_communicator.py b/vllm/distributed/device_communicators/cpu_communicator.py new file mode 100644 index 0000000..94effa0 --- /dev/null +++ b/vllm/distributed/device_communicators/cpu_communicator.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from typing import Optional + +import torch +from torch.distributed import ProcessGroup + +from vllm.platforms import current_platform +from vllm.platforms.interface import CpuArchEnum + +from .base_device_communicator import DeviceCommunicatorBase + + +class CpuCommunicator(DeviceCommunicatorBase): + + def __init__(self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = ""): + super().__init__(cpu_group, device, device_group, unique_name) + self.dist_module = torch.distributed + + if (current_platform.get_cpu_architecture() + == CpuArchEnum.X86) and hasattr( + torch.ops._C, + "init_shm_manager") and unique_name.startswith("tp"): + self.dist_module = _CPUSHMDistributed(self) + + def all_reduce(self, input_): + self.dist_module.all_reduce(input_, group=self.device_group) + return input_ + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + + # Gather. + self.dist_module.gather(input_, + gather_list, + dst=self.ranks[dst], + group=self.device_group) + + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # NOTE: we have to use concat-style all-gather here, + # stack-style all-gather has compatibility issues with + # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 + output_size = (input_size[0] * self.world_size, ) + input_size[1:] + # Allocate output tensor. + output_tensor = torch.empty(output_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. + self.dist_module.all_gather_into_tensor(output_tensor, + input_, + group=self.device_group) + + # Reshape + output_tensor = output_tensor.reshape((self.world_size, ) + input_size) + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (self.world_size * + input_size[dim], ) + + input_size[dim + 1:]) + return output_tensor + + +class _CPUSHMDistributed: + + def __init__(self, communicator: CpuCommunicator): + instance_identifier = os.environ["VLLM_DIST_IDENT"] + unique_name = communicator.unique_name + instance_identifier = f"{instance_identifier}-{unique_name}" + self.communicator = communicator + + group_ranks = [str(rank) for rank in self.communicator.ranks] + shm_group_identifier = f"[{'-'.join(group_ranks)}]" + self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm" + + self.handle = self._init_cpu_shm() + + def _init_cpu_shm(self) -> int: + handle = torch.ops._C.init_shm_manager( + self.group_name, + self.communicator.world_size, + self.communicator.rank, + ) + torch.distributed.barrier(self.communicator.device_group) + torch.ops._C.join_shm_manager( + handle, + self.group_name, + ) + torch.distributed.barrier(self.communicator.device_group) + + return handle + + def all_reduce(self, + input: torch.Tensor, + group: Optional[ProcessGroup] = None) -> None: + torch.ops._C.shm_allreduce(self.handle, input) + + def gather(self, + input: torch.Tensor, + gather_list: Optional[list[torch.Tensor]], + dst: int = -1, + group: Optional[ProcessGroup] = None) -> None: + # Note: different from the torch gather, here we use local dst rank. + torch.ops._C.shm_gather(self.handle, input, gather_list, + torch.distributed.get_group_rank(group, dst)) + + def all_gather_into_tensor(self, + output: torch.Tensor, + input: torch.Tensor, + group: Optional[ProcessGroup] = None) -> None: + torch.ops._C.shm_all_gather(self.handle, input, output) diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py new file mode 100644 index 0000000..3958d56 --- /dev/null +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -0,0 +1,194 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch +from torch.distributed import ProcessGroup + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from .base_device_communicator import DeviceCommunicatorBase + +logger = init_logger(__name__) + + +class CudaCommunicator(DeviceCommunicatorBase): + + def __init__(self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = ""): + super().__init__(cpu_group, device, device_group, unique_name) + if "tp" not in unique_name: + # only tp uses custom allreduce + use_custom_allreduce = False + else: + from vllm.distributed.parallel_state import ( + _ENABLE_CUSTOM_ALL_REDUCE) + use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE + + # ep does not use pynccl + use_pynccl = "ep" not in unique_name + + self.use_pynccl = use_pynccl + self.use_custom_allreduce = use_custom_allreduce + + # lazy import to avoid documentation build error + from vllm.distributed.device_communicators.custom_all_reduce import ( + CustomAllreduce) + from vllm.distributed.device_communicators.pynccl import ( + PyNcclCommunicator) + from vllm.distributed.device_communicators.quick_all_reduce import ( + QuickAllReduce) + + self.pynccl_comm: Optional[PyNcclCommunicator] = None + if use_pynccl and self.world_size > 1: + self.pynccl_comm = PyNcclCommunicator( + group=self.cpu_group, + device=self.device, + ) + + self.ca_comm: Optional[CustomAllreduce] = None + self.qr_comm: Optional[QuickAllReduce] = None + if use_custom_allreduce and self.world_size > 1: + # Initialize a custom fast all-reduce implementation. + self.ca_comm = CustomAllreduce( + group=self.cpu_group, + device=self.device, + ) + + if current_platform.is_rocm(): + # Initialize a custom quick all-reduce implementation for AMD. + # Quick reduce is designed as a complement to custom allreduce. + # Based on quickreduce (https://github.com/mk1-project/quickreduce). + # If it's a rocm, 'use_custom_allreduce==True' means it must + # currently be an MI300 series. + self.qr_comm = QuickAllReduce(group=self.cpu_group, + device=self.device) + if self.use_all2all: + all2all_backend = envs.VLLM_ALL2ALL_BACKEND + if all2all_backend == "naive": + from .all2all import NaiveAll2AllManager + self.all2all_manager = NaiveAll2AllManager(self.cpu_group) + logger.info("Using naive all2all manager.") + elif all2all_backend == "pplx": + from .all2all import PPLXAll2AllManager + self.all2all_manager = PPLXAll2AllManager(self.cpu_group) + logger.info("Using PPLX all2all manager.") + elif all2all_backend == "deepep_high_throughput": + from .all2all import DeepEPHTAll2AllManager + self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group) + logger.info("Using DeepEP High-Throughput all2all manager.") + elif all2all_backend == "deepep_low_latency": + from .all2all import DeepEPLLAll2AllManager + self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group) + logger.info("Using DeepEP Low-Latency all2all manager.") + else: + raise ValueError(f"Unknown all2all backend: {all2all_backend}") + + def all_reduce(self, input_): + # always try quick reduce first, then custom allreduce, + # and then pynccl. (quick reduce just for ROCM MI3*) + qr_comm = self.qr_comm + if qr_comm is not None and not qr_comm.disabled and \ + qr_comm.should_quick_allreduce(input_): + out = qr_comm.quick_all_reduce(input_) + assert out is not None + return out + ca_comm = self.ca_comm + if ca_comm is not None and not ca_comm.disabled and \ + ca_comm.should_custom_ar(input_): + out = ca_comm.custom_all_reduce(input_) + assert out is not None + return out + pynccl_comm = self.pynccl_comm + assert pynccl_comm is not None + out = pynccl_comm.all_reduce(input_) + if out is None: + # fall back to the default all-reduce using PyTorch. + # this usually happens during testing. + # when we run the model, allreduce only happens for the TP + # group, where we always have either custom allreduce or pynccl. + out = input_.clone() + torch.distributed.all_reduce(out, group=self.device_group) + return out + + def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): + world_size = self.world_size + pynccl_comm = self.pynccl_comm + assert pynccl_comm is not None + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Note: This will produce an incorrect answer if we don't make + # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? + input_tensor = input_.movedim(0, dim).contiguous() + + assert input_tensor.shape[0] % world_size == 0 + chunk_size = input_tensor.shape[0] // world_size + output_shape = (chunk_size, ) + input_tensor.shape[1:] + + output = torch.empty(output_shape, + dtype=input_tensor.dtype, + device=input_tensor.device) + + pynccl_comm.reduce_scatter(output, input_) + + # Reshape before returning + return output.movedim(0, dim).contiguous() + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.send(tensor, dst) + else: + torch.distributed.send(tensor, self.ranks[dst], self.device_group) + + def recv(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None) -> torch.Tensor: + """Receives a tensor from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + if src is None: + src = (self.rank_in_group - 1) % self.world_size + + tensor = torch.empty(size, dtype=dtype, device=self.device) + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.recv(tensor, src) + else: + torch.distributed.recv(tensor, self.ranks[src], self.device_group) + return tensor + + def destroy(self): + if self.pynccl_comm is not None: + self.pynccl_comm = None + if self.ca_comm is not None: + self.ca_comm = None + if self.all2all_manager is not None: + self.all2all_manager.destroy() + self.all2all_manager = None + + def dispatch( + self, hidden_states: torch.Tensor, + router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + assert self.all2all_manager is not None + hidden_states, router_logits = self.all2all_manager.dispatch( + hidden_states, router_logits) + return hidden_states, router_logits + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + assert self.all2all_manager is not None + hidden_states = self.all2all_manager.combine(hidden_states) + return hidden_states diff --git a/vllm/distributed/device_communicators/cuda_wrapper.py b/vllm/distributed/device_communicators/cuda_wrapper.py new file mode 100644 index 0000000..2c38e8e --- /dev/null +++ b/vllm/distributed/device_communicators/cuda_wrapper.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""This file is a pure Python wrapper for the cudart library. +It avoids the need to compile a separate shared library, and is +convenient for use when we just need to call a few functions. +""" + +import ctypes +from dataclasses import dataclass +from typing import Any, Optional + +# this line makes it possible to directly load `libcudart.so` using `ctypes` +import torch # noqa + +import vllm.envs as envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + +# === export types and functions from cudart to Python === +# for the original cudart definition, please check +# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html + +cudaError_t = ctypes.c_int +cudaMemcpyKind = ctypes.c_int + + +class cudaIpcMemHandle_t(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +@dataclass +class Function: + name: str + restype: Any + argtypes: list[Any] + + +def find_loaded_library(lib_name) -> Optional[str]: + """ + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, + the file `/proc/self/maps` contains the memory maps of the process, which includes the + shared libraries loaded by the process. We can use this file to find the path of the + a loaded library. + """ # noqa + found = False + with open("/proc/self/maps") as f: + for line in f: + if lib_name in line: + found = True + break + if not found: + # the library is not loaded in the current process + return None + # if lib_name is libcudart, we need to match a line with: + # address /path/to/libcudart-hash.so.11.0 + start = line.index("/") + path = line[start:].strip() + filename = path.split("/")[-1] + assert filename.rpartition(".so")[0].startswith(lib_name), \ + f"Unexpected filename: {filename} for library {lib_name}" + return path + + +class CudaRTLibrary: + exported_functions = [ + # ​cudaError_t cudaSetDevice ( int device ) + Function("cudaSetDevice", cudaError_t, [ctypes.c_int]), + # cudaError_t cudaDeviceSynchronize ( void ) + Function("cudaDeviceSynchronize", cudaError_t, []), + # ​cudaError_t cudaDeviceReset ( void ) + Function("cudaDeviceReset", cudaError_t, []), + + # const char* cudaGetErrorString ( cudaError_t error ) + Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]), + + # ​cudaError_t cudaMalloc ( void** devPtr, size_t size ) + Function("cudaMalloc", cudaError_t, + [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]), + # ​cudaError_t cudaFree ( void* devPtr ) + Function("cudaFree", cudaError_t, [ctypes.c_void_p]), + # ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count ) + Function("cudaMemset", cudaError_t, + [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]), + # ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa + Function("cudaMemcpy", cudaError_t, [ + ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind + ]), + + # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa + Function("cudaIpcGetMemHandle", cudaError_t, + [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]), + # ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa + Function("cudaIpcOpenMemHandle", cudaError_t, [ + ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint + ]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: dict[str, dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + if so_file is None: + so_file = find_loaded_library("libcudart") + if so_file is None: + so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var + assert so_file is not None, \ + ( + "libcudart is not loaded in the current process, " + "try setting VLLM_CUDART_SO_PATH" + ) + if so_file not in CudaRTLibrary.path_to_library_cache: + lib = ctypes.CDLL(so_file) + CudaRTLibrary.path_to_library_cache[so_file] = lib + self.lib = CudaRTLibrary.path_to_library_cache[so_file] + + if so_file not in CudaRTLibrary.path_to_dict_mapping: + _funcs = {} + for func in CudaRTLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs + self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file] + + def CUDART_CHECK(self, result: cudaError_t) -> None: + if result != 0: + error_str = self.cudaGetErrorString(result) + raise RuntimeError(f"CUDART error: {error_str}") + + def cudaGetErrorString(self, error: cudaError_t) -> str: + return self.funcs["cudaGetErrorString"](error).decode("utf-8") + + def cudaSetDevice(self, device: int) -> None: + self.CUDART_CHECK(self.funcs["cudaSetDevice"](device)) + + def cudaDeviceSynchronize(self) -> None: + self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]()) + + def cudaDeviceReset(self) -> None: + self.CUDART_CHECK(self.funcs["cudaDeviceReset"]()) + + def cudaMalloc(self, size: int) -> ctypes.c_void_p: + devPtr = ctypes.c_void_p() + self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size)) + return devPtr + + def cudaFree(self, devPtr: ctypes.c_void_p) -> None: + self.CUDART_CHECK(self.funcs["cudaFree"](devPtr)) + + def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, + count: int) -> None: + self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count)) + + def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p, + count: int) -> None: + cudaMemcpyDefault = 4 + kind = cudaMemcpyDefault + self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind)) + + def cudaIpcGetMemHandle(self, + devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t: + handle = cudaIpcMemHandle_t() + self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"]( + ctypes.byref(handle), devPtr)) + return handle + + def cudaIpcOpenMemHandle(self, + handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: + cudaIpcMemLazyEnablePeerAccess = 1 + devPtr = ctypes.c_void_p() + self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"]( + ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess)) + return devPtr diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py new file mode 100644 index 0000000..5519073 --- /dev/null +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -0,0 +1,316 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from contextlib import contextmanager +from typing import Optional, Union + +import os +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.distributed.device_communicators.custom_all_reduce_utils import ( + gpu_p2p_access_check) +from vllm.distributed.parallel_state import in_the_same_node_as +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import cuda_device_count_stateless + +try: + ops.meta_size() + custom_ar = True +except Exception: + # For CPUs + custom_ar = False + +logger = init_logger(__name__) + + +def _can_p2p(rank: int, world_size: int) -> bool: + for i in range(world_size): + if i == rank: + continue + if envs.VLLM_SKIP_P2P_CHECK: + logger.info( + "Skipping P2P check and trusting the driver's P2P report.") + return torch.cuda.can_device_access_peer(rank, i) + if not gpu_p2p_access_check(rank, i): + return False + return True + + +def is_weak_contiguous(inp: torch.Tensor): + return inp.is_contiguous() or (inp.storage().nbytes() - + inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size()) + + +class CustomAllreduce: + + _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8, 16] + + # max_size: max supported allreduce size + def __init__(self, + group: ProcessGroup, + device: Union[int, str, torch.device], + max_size=8192 * 512) -> None: + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the CustomAllreduce to. If None, + it will be bind to f"cuda:{local_rank}". + It is the caller's responsibility to make sure each communicator + is bind to a unique device, and all communicators in this group + are in the same node. + """ + self._IS_CAPTURING = False + self.disabled = True + + if not custom_ar: + # disable because of missing custom allreduce library + # e.g. in a non-GPU environment + logger.info("Custom allreduce is disabled because " + "of missing custom allreduce library") + return + + self.group = group + + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "CustomAllreduce should be attached to a non-NCCL group.") + + if not all(in_the_same_node_as(group, source_rank=0)): + # No need to initialize custom allreduce for multi-node case. + logger.warning( + "Custom allreduce is disabled because this process group" + " spans across nodes.") + return + + rank = dist.get_rank(group=self.group) + self.rank = rank + world_size = dist.get_world_size(group=self.group) + + if world_size > envs.VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX: + return + + if world_size == 1: + # No need to initialize custom allreduce for single GPU case. + return + + if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES: + logger.warning( + "Custom allreduce is disabled due to an unsupported world" + " size: %d. Supported world sizes: %s. To silence this " + "warning, specify disable_custom_all_reduce=True explicitly.", + world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES)) + return + + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices: + device_ids = list(map(int, cuda_visible_devices.split(","))) + else: + device_ids = list(range(cuda_device_count_stateless())) + + physical_device_id = device_ids[device.index] + tensor = torch.tensor([physical_device_id], + dtype=torch.int, + device="cpu") + gather_list = [ + torch.tensor([0], dtype=torch.int, device="cpu") + for _ in range(world_size) + ] + dist.all_gather(gather_list, tensor, group=self.group) + physical_device_ids = [t.item() for t in gather_list] + + # test nvlink first, this will filter out most of the cases + # where custom allreduce is not supported + # this checks hardware and driver support for NVLink + assert current_platform.is_cuda_alike() + fully_connected = current_platform.is_fully_connected( + physical_device_ids) + + # if world_size > 2 and not fully_connected: + if not fully_connected: + max_size = 32 * 8192 * 2 + if not envs.VLLM_PCIE_USE_CUSTOM_ALLREDUCE: + logger.warning( + "Custom allreduce is disabled because it's not supported on" + " more than two PCIe-only GPUs. To silence this warning, " + "specify disable_custom_all_reduce=True explicitly.") + return + logger.warning( + "We are using PCIe's custom allreduce." + "If the performance is poor, we can add " + "--disable-custom-all-reduce in the instruction.") + # test P2P capability, this checks software/cudaruntime support + # this is expensive to compute at the first time + # then we cache the result + # On AMD GPU, p2p is always enabled between XGMI connected GPUs + if not current_platform.is_rocm() and not _can_p2p(rank, world_size): + logger.warning( + "Custom allreduce is disabled because your platform lacks " + "GPU P2P capability or P2P test failed. To silence this " + "warning, specify disable_custom_all_reduce=True explicitly.") + return + + self.disabled = False + # Buffers memory are owned by this Python class and passed to C++. + # Meta data composes of two parts: meta data for synchronization and a + # temporary buffer for storing intermediate allreduce results. + self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size, + group=group, + uncached=True) + # This is a pre-registered IPC buffer. In eager mode, input tensors + # are first copied into this buffer before allreduce is performed + self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) + # This is a buffer for storing the tuples of pointers pointing to + # IPC buffers from all ranks. Each registered tuple has size of + # 8*world_size bytes where world_size is at most 8. Allocating 8MB + # is enough for 131072 such tuples. The largest model I've seen only + # needs less than 10000 of registered tuples. + self.rank_data = torch.empty(8 * 1024 * 1024, + dtype=torch.uint8, + device=self.device) + self.max_size = max_size + self.rank = rank + self.world_size = world_size + self.fully_connected = fully_connected + self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank, + self.fully_connected) + ops.register_buffer(self._ptr, self.buffer_ptrs) + + @contextmanager + def capture(self): + """ + The main responsibility of this context manager is the + `register_graph_buffers` call at the end of the context. + It records all the buffer addresses used in the CUDA graph. + """ + try: + self._IS_CAPTURING = True + yield + finally: + self._IS_CAPTURING = False + if not self.disabled: + self.register_graph_buffers() + + def register_graph_buffers(self): + handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) + logger.info("Registering %d cuda graph addresses", len(offset)) + # We cannot directly use `dist.all_gather_object` here + # because it is incompatible with `gloo` backend under inference mode. + # see https://github.com/pytorch/pytorch/issues/126032 for details. + all_data = [[None, None] + for _ in range(dist.get_world_size(group=self.group))] + all_data[self.rank] = [handle, offset] + ranks = sorted(dist.get_process_group_ranks(group=self.group)) + for i, rank in enumerate(ranks): + dist.broadcast_object_list(all_data[i], + src=rank, + group=self.group, + device="cpu") + # Unpack list of tuples to tuple of lists. + handles = [d[0] for d in all_data] # type: ignore + offsets = [d[1] for d in all_data] # type: ignore + ops.register_graph_buffers(self._ptr, handles, offsets) + + def should_custom_ar(self, inp: torch.Tensor): + if self.disabled: + return False + inp_size = inp.numel() * inp.element_size() + # custom allreduce requires input byte size to be multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + # for 4 or more non NVLink-capable GPUs, custom allreduce provides + # little performance improvement over NCCL. + return inp_size <= self.max_size + + def all_reduce(self, + inp: torch.Tensor, + *, + out: torch.Tensor = None, + registered: bool = False): + """Performs an out-of-place all reduce. + + If registered is True, this assumes inp's pointer is already + IPC-registered. Otherwise, inp is first copied into a pre-registered + buffer. + """ + if out is None: + out = torch.empty_like(inp) + if registered: + ops.all_reduce(self._ptr, inp, out, 0, 0) + else: + ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], + self.max_size) + return out + + def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: + """The main allreduce API that provides support for cuda graph.""" + # When custom allreduce is disabled, this will be None. + if self.disabled or not self.should_custom_ar(input): + return None + if self._IS_CAPTURING: + if torch.cuda.is_current_stream_capturing(): + return self.all_reduce(input, registered=False) + else: + # If warm up, mimic the allocation pattern since custom + # allreduce is out-of-place. + return torch.empty_like(input) + else: + # Note: outside of cuda graph context, custom allreduce incurs a + # cost of cudaMemcpy, which should be small (<=1% of overall + # latency) compared to the performance gain of using custom kernels + return self.all_reduce(input, registered=False) + + def close(self): + if not self.disabled and self._ptr: + if ops is not None: + ops.dispose(self._ptr) + self._ptr = 0 + self.free_shared_buffer(self.meta_ptrs, rank=self.rank) + self.free_shared_buffer(self.buffer_ptrs, rank=self.rank) + + def __del__(self): + self.close() + + + @staticmethod + def create_shared_buffer(size_in_bytes: int, + group: Optional[ProcessGroup] = None, + uncached: Optional[bool] = False) -> list[int]: + pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes) + + world_size = dist.get_world_size(group=group) + rank = dist.get_rank(group=group) + handles = [None] * world_size + dist.all_gather_object(handles, handle, group=group) + + pointers: list[int] = [] + for i, h in enumerate(handles): + if i == rank: + pointers.append(pointer) # type: ignore + else: + pointers.append(ops.open_mem_handle(h)) + return pointers + + @staticmethod + def free_shared_buffer(pointers: list[int], + group: Optional[ProcessGroup] = None, + rank: Optional[int] = 0) -> None: + if rank is None: + rank = dist.get_rank(group=group) + if ops is not None: + ops.free_shared_buffer(pointers[rank]) diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/custom_all_reduce_utils.py new file mode 100644 index 0000000..7c6001e --- /dev/null +++ b/vllm/distributed/device_communicators/custom_all_reduce_utils.py @@ -0,0 +1,259 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import ctypes +import json +import os +import pickle +import subprocess +import sys +import tempfile +from collections.abc import Sequence +from itertools import product +from typing import Optional + +import torch.distributed as dist +import torch.multiprocessing as mp + +import vllm.envs as envs +from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary +from vllm.logger import init_logger +from vllm.utils import (cuda_device_count_stateless, + update_environment_variables) + +logger = init_logger(__name__) + + +def producer(batch_src: Sequence[int], + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices: Optional[str] = None): + if cuda_visible_devices is not None: + update_environment_variables( + {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + + lib = CudaRTLibrary() + for i in batch_src: + lib.cudaSetDevice(i) + pointer = lib.cudaMalloc(1024) + lib.cudaMemset(pointer, 1, 1024) + lib.cudaDeviceSynchronize() + handle = lib.cudaIpcGetMemHandle(pointer) + producer_queue.put(handle) + open_success = consumer_queue.get() + if open_success: + # use two queues to simulate barrier + producer_queue.put(0) + consumer_queue.get() + # check if the memory is modified + host_data = (ctypes.c_char * 1024)() + lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore + for i in range(1024): + if ord(host_data[i]) != 2: + open_success = False + break + result_queue.put(open_success) + lib.cudaDeviceReset() + + +def consumer(batch_tgt: Sequence[int], + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices: Optional[str] = None): + if cuda_visible_devices is not None: + update_environment_variables( + {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + + lib = CudaRTLibrary() + for j in batch_tgt: + lib.cudaSetDevice(j) + handle = producer_queue.get() + open_success = False + try: + pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore + open_success = True + except RuntimeError: + # cannot error out here, because the producer process + # is still waiting for the response. + pass + consumer_queue.put(open_success) + if open_success: + # modify the memory + lib.cudaMemset(pointer, 2, 1024) + lib.cudaDeviceSynchronize() + # use two queues to simulate barrier + producer_queue.get() + consumer_queue.put(0) + # check if the memory is modified + host_data = (ctypes.c_char * 1024)() + lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore + for i in range(1024): + if ord(host_data[i]) != 2: + open_success = False + break + result_queue.put(open_success) + lib.cudaDeviceReset() + + +def can_actually_p2p( + batch_src: Sequence[int], + batch_tgt: Sequence[int], +) -> Sequence[bool]: + """ + Usually, checking if P2P access is enabled can be done by + `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes + the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)` + returns `True` even if P2P access is not actually possible. + See https://github.com/vllm-project/vllm/issues/2728 and + https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10 + Therefore, we have to perform a real P2P access to check if it is actually + possible. + + Note on p2p and cuda IPC: + Usually, one process uses one GPU: + GPU src --> cuda context src --> tensor src --> process src + + We need to combine p2p and cuda IPC, so that: + GPU src --> cuda context src --> tensor src --> process src + |shared| + GPU tgt --> cuda context tgt --> tensor tgt --> process tgt + That is to say, process src creates a tensor in GPU src, passes IPC handle to + process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the + tensor in process tgt will be reflected in the tensor in process src, because + they are the same memory segment. + It is important to note that process tgt accesses the tensor in GPU tgt, not + GPU src. That's why we need p2p access. + + The most time-consuming part is the process creation. To avoid creating + processes for every pair of GPUs, we use batched testing. We create two + processes for testing all pairs of GPUs in batch. The trick is to reset + the device after each test (which is not available in PyTorch). + """ # noqa + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + # pass the CUDA_VISIBLE_DEVICES to the child process + # to make sure they see the same set of GPUs + + # make sure the processes are spawned + smp = mp.get_context("spawn") + producer_queue = smp.Queue() + consumer_queue = smp.Queue() + result_queue = smp.Queue() + p_src = smp.Process(target=producer, + args=(batch_src, producer_queue, consumer_queue, + result_queue, cuda_visible_devices)) + p_tgt = smp.Process(target=consumer, + args=(batch_tgt, producer_queue, consumer_queue, + result_queue, cuda_visible_devices)) + p_src.start() + p_tgt.start() + p_src.join() + p_tgt.join() + assert p_src.exitcode == 0 and p_tgt.exitcode == 0 + result: list[bool] = [] + for src, tgt in zip(batch_src, batch_tgt): + a = result_queue.get() + b = result_queue.get() + if a != b: + logger.warning( + "Two processes do not agree on the P2P access" + " status on %d -> %d, treat as disabled.", src, tgt) + result.append(False) + else: + result.append(a) + return result + + +# why do we need this cache? +# we are testing peer-to-peer (p2p) access between GPUs,across processes. +# if we test it every time, it will be very slow, because we need to create +# N * N * 2 processes, where N is the world size. This is very slow. +# to reduce the time, we use a cache file to store the p2p access status. +# the cache file is generated by the master process if it does not exist. +# then all the processes can read the cache file to check the p2p access status. +# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we +# can have different cache files for different CUDA_VISIBLE_DEVICES settings, +# e.g. used by different vllm engines. The device id in the cache file is a +# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number +# of visible devices in the vllm engine. +_gpu_p2p_access_cache: Optional[dict[str, bool]] = None + + +def gpu_p2p_access_check(src: int, tgt: int) -> bool: + """Check if GPU src can access GPU tgt.""" + + # if the cache variable is already calculated, + # read from the cache instead of checking it again + global _gpu_p2p_access_cache + if _gpu_p2p_access_cache is not None: + return _gpu_p2p_access_cache[f"{src}->{tgt}"] + + is_distributed = dist.is_initialized() + + num_dev = cuda_device_count_stateless() + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices is None: + cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) + + path = os.path.join( + envs.VLLM_CACHE_ROOT, + f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json") + os.makedirs(os.path.dirname(path), exist_ok=True) + from vllm.distributed.parallel_state import get_world_group + if ((not is_distributed or get_world_group().local_rank == 0) + and (not os.path.exists(path))): + # only the local master process (with local_rank == 0) can + # enter this block to calculate the cache + logger.info("generating GPU P2P access cache in %s", path) + cache: dict[str, bool] = {} + ids = list(range(num_dev)) + # batch of all pairs of GPUs + batch_src, batch_tgt = zip(*list(product(ids, ids))) + # NOTE: we use `subprocess` rather than `multiprocessing` here + # because the caller might not have `if __name__ == "__main__":`, + # in that case we cannot use spawn method in multiprocessing. + # However, `can_actually_p2p` requires spawn method. + # The fix is, we use `subprocess` to call the function, + # where we have `if __name__ == "__main__":` in this file. + + # use a temporary file to store the result + # we don't use the output of the subprocess directly, + # because the subprocess might produce logging output + with tempfile.NamedTemporaryFile() as output_file: + input_bytes = pickle.dumps( + (batch_src, batch_tgt, output_file.name)) + returned = subprocess.run([sys.executable, __file__], + input=input_bytes, + capture_output=True) + # check if the subprocess is successful + try: + returned.check_returncode() + except Exception as e: + # wrap raised exception to provide more information + raise RuntimeError( + f"Error happened when batch testing " + f"peer-to-peer access from {batch_src} to {batch_tgt}:\n" + f"{returned.stderr.decode()}") from e + with open(output_file.name, "rb") as f: + result = pickle.load(f) + for _i, _j, r in zip(batch_src, batch_tgt, result): + cache[f"{_i}->{_j}"] = r + with open(path, "w") as f: + json.dump(cache, f, indent=4) + if is_distributed: + get_world_group().barrier() + logger.info("reading GPU P2P access cache from %s", path) + with open(path) as f: + cache = json.load(f) + _gpu_p2p_access_cache = cache + return _gpu_p2p_access_cache[f"{src}->{tgt}"] + + +__all__ = ["gpu_p2p_access_check"] + +if __name__ == "__main__": + batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read()) + result = can_actually_p2p(batch_src, batch_tgt) + with open(output_file, "wb") as f: + f.write(pickle.dumps(result)) diff --git a/vllm/distributed/device_communicators/hpu_communicator.py b/vllm/distributed/device_communicators/hpu_communicator.py new file mode 100644 index 0000000..f00f6b6 --- /dev/null +++ b/vllm/distributed/device_communicators/hpu_communicator.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch.distributed as dist + +from vllm.platforms import current_platform + +from .base_device_communicator import DeviceCommunicatorBase + +if current_platform.is_hpu(): + import habana_frameworks.torch as htorch # noqa: F401 + + +class HpuCommunicator(DeviceCommunicatorBase): + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge + # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used + # (which is required for tensor parallel HPUGraph inference) + htorch.core.mark_step() + dist.all_reduce(input_, group=self.device_group) + return input_ + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + world_size = self.world_size + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # Allocate output tensor. + output_tensor = torch.empty((world_size, ) + input_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. + htorch.core.mark_step() + dist.all_gather_into_tensor(output_tensor, + input_, + group=self.device_group) + # Reshape + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (world_size * + input_size[dim], ) + + input_size[dim + 1:]) + return output_tensor diff --git a/vllm/distributed/device_communicators/neuron_communicator.py b/vllm/distributed/device_communicators/neuron_communicator.py new file mode 100644 index 0000000..5b61a16 --- /dev/null +++ b/vllm/distributed/device_communicators/neuron_communicator.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.distributed.device_communicators.base_device_communicator import ( + DeviceCommunicatorBase) +from vllm.platforms import current_platform + +if current_platform.is_neuron(): + import torch_xla.core.xla_model as xm + + +class NeuronCommunicator(DeviceCommunicatorBase): + + def all_reduce(self, x: torch.Tensor) -> torch.Tensor: + return xm.all_reduce(xm.REDUCE_SUM, x) + + def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + assert dim == -1, "Neuron only supports dim=-1 for all-gather." + return xm.all_gather(x, dim=dim) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py new file mode 100644 index 0000000..2948629 --- /dev/null +++ b/vllm/distributed/device_communicators/pynccl.py @@ -0,0 +1,218 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, Union + +# ===================== import region ===================== +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup, ReduceOp + +from vllm.distributed.device_communicators.pynccl_wrapper import ( + NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, + ncclRedOpTypeEnum, ncclUniqueId) +from vllm.distributed.utils import StatelessProcessGroup +from vllm.logger import init_logger +from vllm.utils import current_stream + +logger = init_logger(__name__) + + +class PyNcclCommunicator: + + def __init__( + self, + group: Union[ProcessGroup, StatelessProcessGroup], + device: Union[int, str, torch.device], + library_path: Optional[str] = None, + ): + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the PyNcclCommunicator to. If None, + it will be bind to f"cuda:{local_rank}". + library_path: the path to the NCCL library. If None, it will + use the default library path. + It is the caller's responsibility to make sure each communicator + is bind to a unique device. + """ + if not isinstance(group, StatelessProcessGroup): + assert dist.is_initialized() + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "PyNcclCommunicator should be attached to a non-NCCL group.") + # note: this rank is the rank in the group + self.rank = dist.get_rank(group) + self.world_size = dist.get_world_size(group) + else: + self.rank = group.rank + self.world_size = group.world_size + + self.group = group + + # if world_size == 1, no need to create communicator + if self.world_size == 1: + self.available = False + self.disabled = True + return + try: + self.nccl = NCCLLibrary(library_path) + except Exception: + # disable because of missing NCCL library + # e.g. in a non-GPU environment + self.available = False + self.disabled = True + return + + self.available = True + self.disabled = False + + logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion()) + + if self.rank == 0: + # get the unique id from NCCL + self.unique_id = self.nccl.ncclGetUniqueId() + else: + # construct an empty unique id + self.unique_id = ncclUniqueId() + + if not isinstance(group, StatelessProcessGroup): + tensor = torch.ByteTensor(list(self.unique_id.internal)) + ranks = dist.get_process_group_ranks(group) + # arg `src` in `broadcast` is the global rank + dist.broadcast(tensor, src=ranks[0], group=group) + byte_list = tensor.tolist() + for i, byte in enumerate(byte_list): + self.unique_id.internal[i] = byte + else: + self.unique_id = group.broadcast_obj(self.unique_id, src=0) + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + # nccl communicator and stream will use this device + # `torch.cuda.device` is a context manager that changes the + # current cuda device to the specified one + with torch.cuda.device(device): + self.comm: ncclComm_t = self.nccl.ncclCommInitRank( + self.world_size, self.unique_id, self.rank) + + stream = current_stream() + # A small all_reduce for warmup. + data = torch.zeros(1, device=device) + self.all_reduce(data) + stream.synchronize() + del data + + def all_reduce(self, + in_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None) -> torch.Tensor: + if self.disabled: + return None + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert in_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {in_tensor.device}") + + out_tensor = torch.empty_like(in_tensor) + + if stream is None: + stream = current_stream() + self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + ncclDataTypeEnum.from_torch(in_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) + return out_tensor + + def all_gather(self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + stream=None): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}") + if stream is None: + stream = current_stream() + self.nccl.ncclAllGather( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm, + cudaStream_t(stream.cuda_stream)) + + def reduce_scatter(self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}") + if stream is None: + stream = current_stream() + self.nccl.ncclReduceScatter( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), output_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) + + def send(self, tensor: torch.Tensor, dst: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), dst, + self.comm, cudaStream_t(stream.cuda_stream)) + + def recv(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + self.comm, cudaStream_t(stream.cuda_stream)) + + def broadcast(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + if src == self.rank: + sendbuff = buffer_type(tensor.data_ptr()) + # NCCL requires the sender also to have a receive buffer + recvbuff = buffer_type(tensor.data_ptr()) + else: + sendbuff = buffer_type() + recvbuff = buffer_type(tensor.data_ptr()) + self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + self.comm, cudaStream_t(stream.cuda_stream)) diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py new file mode 100644 index 0000000..df9c239 --- /dev/null +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -0,0 +1,349 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` +# contains many other potential cuda APIs, that are not allowed during +# capturing the CUDA graph. For further details, please check +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . +# +# Another rejected idea is to write a C/C++ binding for NCCL. It is usually +# doable, but we often encounter issues related with nccl versions, and need +# to switch between different versions of NCCL. See +# https://github.com/NVIDIA/nccl/issues/1234 for more details. +# A C/C++ binding is not flexible enough to handle this. It requires +# recompilation of the code every time we want to switch between different +# versions. This current implementation, with a **pure** Python wrapper, is +# more flexible. We can easily switch between different versions of NCCL by +# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` +# variable in the code. + +import ctypes +import platform +from dataclasses import dataclass +from typing import Any, Optional + +import torch +from torch.distributed import ReduceOp + +from vllm.logger import init_logger +from vllm.utils import find_nccl_library + +logger = init_logger(__name__) + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int +ncclComm_t = ctypes.c_void_p + + +class ncclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +cudaStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + +ncclDataType_t = ctypes.c_int + + +class ncclDataTypeEnum: + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +ncclRedOp_t = ctypes.c_int + + +class ncclRedOpTypeEnum: + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + argtypes: list[Any] + + +class NCCLLibrary: + exported_functions = [ + # const char* ncclGetErrorString(ncclResult_t result) + Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), + # ncclResult_t ncclGetVersion(int *version); + Function("ncclGetVersion", ncclResult_t, + [ctypes.POINTER(ctypes.c_int)]), + # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); + Function("ncclGetUniqueId", ncclResult_t, + [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommInitRank( + # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + # note that ncclComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function("ncclCommInitRank", ncclResult_t, [ + ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, + ctypes.c_int + ]), + # ncclResult_t ncclAllReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllReduce", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclAllGather( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllGather", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclReduceScatter( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclReduceScatter", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclSend( + # const void* sendbuff, size_t count, ncclDataType_t datatype, + # int dest, ncclComm_t comm, cudaStream_t stream); + Function("ncclSend", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclRecv( + # void* recvbuff, size_t count, ncclDataType_t datatype, + # int src, ncclComm_t comm, cudaStream_t stream); + Function("ncclRecv", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclBroadcast( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, int root, ncclComm_t comm, + # cudaStream_t stream); + Function("ncclBroadcast", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ctypes.c_int, ncclComm_t, cudaStream_t + ]), + + # be cautious! this is a collective call, it will block until all + # processes in the communicator have called this function. + # because Python object destruction can happen in random order, + # it is better not to call it at all. + # ncclResult_t ncclCommDestroy(ncclComm_t comm); + Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: dict[str, dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + + so_file = so_file or find_nccl_library() + + try: + if so_file not in NCCLLibrary.path_to_dict_mapping: + lib = ctypes.CDLL(so_file) + NCCLLibrary.path_to_library_cache[so_file] = lib + self.lib = NCCLLibrary.path_to_library_cache[so_file] + except Exception as e: + logger.error( + "Failed to load NCCL library from %s. " + "It is expected if you are not running on NVIDIA/hcus." + "Otherwise, the nccl library might not exist, be corrupted " + "or it does not support the current platform %s. " + "If you already have the library, please set the " + "environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path.", so_file, + platform.platform()) + raise e + + if so_file not in NCCLLibrary.path_to_dict_mapping: + _funcs: dict[str, Any] = {} + for func in NCCLLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + NCCLLibrary.path_to_dict_mapping[so_file] = _funcs + self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] + + def ncclGetErrorString(self, result: ncclResult_t) -> str: + return self._funcs["ncclGetErrorString"](result).decode("utf-8") + + def NCCL_CHECK(self, result: ncclResult_t) -> None: + if result != 0: + error_str = self.ncclGetErrorString(result) + raise RuntimeError(f"NCCL error: {error_str}") + + def ncclGetVersion(self) -> str: + version = ctypes.c_int() + self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) + version_str = str(version.value) + # something like 21903 --> "2.19.3" + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + def ncclGetUniqueId(self) -> ncclUniqueId: + unique_id = ncclUniqueId() + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"]( + ctypes.byref(unique_id))) + return unique_id + + def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId: + if len(data) != 128: + raise ValueError( + f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes") + unique_id = ncclUniqueId() + ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128) + return unique_id + + def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, + rank: int) -> ncclComm_t: + comm = ncclComm_t() + self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), + world_size, unique_id, + rank)) + return comm + + def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, + datatype, op, comm, + stream)) + + def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff, + count, datatype, op, + comm, stream)) + + def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # which is an aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count, + datatype, comm, stream)) + + def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, + dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, + dest, comm, stream)) + + def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, + src: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, + comm, stream)) + + def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, root: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count, + datatype, root, comm, + stream)) + + def ncclCommDestroy(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + + +__all__ = [ + "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", + "ncclComm_t", "cudaStream_t", "buffer_type" +] diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py new file mode 100644 index 0000000..c61231e --- /dev/null +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -0,0 +1,278 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from enum import Enum +from typing import Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.config import get_current_vllm_config +from vllm.distributed.parallel_state import in_the_same_node_as +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import cuda_device_count_stateless + +logger = init_logger(__name__) + +try: + ops.qr_max_size() + quick_ar = True +except Exception: + # For CPUs and CUDA + quick_ar = False + + +def is_weak_contiguous(inp: torch.Tensor): + return inp.is_contiguous() or (inp.storage().nbytes() - + inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size()) + + +class QuickReduceRegime(Enum): + FP = 0 + INT8 = 1 + INT6 = 2 + INT4 = 3 + NONE = 4 + + +MB = 1024 * 1024 + + +class QuickAllReduce: + + _SUPPORTED_WORLD_SIZES = [2, 4, 8] + _SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] + # The following data is based on kernel tests. + # In this order [FP, INT8, INT6, INT4]. + _QR_MIN_SIZE = { + (torch.float16, 2): [1 * MB, 2 * MB, 2 * MB, 1 * MB], + (torch.float16, 4): [1 * MB, 16 * MB, 4 * MB, 2 * MB], + (torch.float16, 8): [16 * MB, 4 * MB, 4 * MB, 2 * MB], + (torch.bfloat16, 2): [2 * MB, 8 * MB, 8 * MB, 8 * MB], + (torch.bfloat16, 4): [8 * MB, 64 * MB, 64 * MB, 16 * MB], + (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB], + } + + def __init__(self, group: ProcessGroup, + device: Union[int, str, torch.device]) -> None: + """ + Custom allreduce provides non-destructive acceleration and is + available for CUDA and ROCm MI300 series. + + Custom quick allreduce leverages quantization for further + acceleration on ROCm. It currently supports Q8, Q6, and Q4 + quantization formats and FP(float16, bfloat16). + + Quick allreduce is designed as a complement to custom allreduce. + Its initialization requires even stricter conditions. + + Only the ROCm MI300 series is supported for quick allreduce at + this time. + + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the CustomAllreduce to. If None, + it will be bind to f"cuda:{local_rank}". + It is the caller's responsibility to make sure each communicator + is bind to a unique device, and all communicators in this group + are in the same node. + """ + self.disabled = True + if not self._rocm_arch_available(): + logger.debug( + "Custom quick allreduce is only supported on ROCm MI300 series." + ) + return + + if not quick_ar: + # disable because of missing quick reduce library + # e.g. in a cuda environment + logger.info("Custom quick allreduce is disabled because " + "of missing custom quick allreduce library") + return + + self.group = group + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "Custom quick allreduce should be attached to a non-NCCL group.") + if not all(in_the_same_node_as(group, source_rank=0)): + # No need to initialize custom quick allreduce for + # multi-node case. + logger.warning("Custom quick allreduce is disabled because this " + "process group spans across nodes.") + return + rank = dist.get_rank(group=self.group) + world_size = dist.get_world_size(group=self.group) + self.rank = rank + self.world_size = world_size + if world_size == 1: + # No need to initialize QuickReduce for single GPU case. + return + + if world_size not in QuickAllReduce._SUPPORTED_WORLD_SIZES: + logger.warning( + "Custom quick allreduce is disabled due to an " + "unsupported world size: %d. Supported world sizes: %s.", + world_size, str(QuickAllReduce._SUPPORTED_WORLD_SIZES)) + return + + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + assert isinstance(device, torch.device) + self.device = device + + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices: + device_ids = list(map(int, cuda_visible_devices.split(","))) + else: + device_ids = list(range(cuda_device_count_stateless())) + physical_device_id = device_ids[device.index] + tensor = torch.tensor([physical_device_id], + dtype=torch.int, + device="cpu") + gather_list = [ + torch.tensor([0], dtype=torch.int, device="cpu") + for _ in range(self.world_size) + ] + dist.all_gather(gather_list, tensor, group=self.group) + physical_device_ids = [t.item() for t in gather_list] + + # test nvlink first, this will filter out most of the cases + # where custom quick allreduce is not supported + # this checks hardware and driver support for NVLink + assert current_platform.is_cuda_alike() + self.fully_connected = current_platform.is_fully_connected( + physical_device_ids) + if self.world_size > 2 and not self.fully_connected: + logger.debug( + "Custom quick allreduce is disabled because it's not supported " + "on more than two PCIe-only GPUs. ") + return + + self.init_quick_all_reduce() + + def init_quick_all_reduce(self): + # On RocM, bfloat16 kernels are slower than fp16 + # due to slower match operations + # If environment variable is set to 1, we convert input to fp16 + self.use_fp16_kernels = envs.VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16 + regime_str = envs.VLLM_ROCM_QUICK_REDUCE_QUANTIZATION + if regime_str not in QuickReduceRegime.__members__: + logger.warning( + "Custom quick allreduce:", + f"Invalid quantization level: {regime_str}. " + "Supported levels: " + f"{list(QuickReduceRegime.__members__.keys())}") + return + + if regime_str == "NONE": + logger.debug("Custom quick allreduce is disabled based " + "on env variable " + "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION='NONE'") + return + self.qr_quant_level = QuickReduceRegime[regime_str] + vllm_config = get_current_vllm_config() + if vllm_config is not None and \ + hasattr(vllm_config, "model_config") and \ + hasattr(vllm_config.model_config, "dtype"): + dtype = vllm_config.model_config.dtype + if dtype not in [torch.float16, torch.bfloat16]: + logger.debug( + "Custom quick allreduce disabled: only supports " + "float16 and float16, but get %s.", dtype) + return + + if dtype == torch.bfloat16 and self.use_fp16_kernels: + logger.info( + "Custom quick allreduce: BF16 inputs will be converted " + "to FP16 to improve performance. set " + "envs.VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16=0 " + "to turn off.") + + # VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB + qr_max_size = envs.VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB + if qr_max_size is not None: + if qr_max_size < 1: + logger.info( + "You should not set a max_size smaller than 1MB, which can " + "lead to error or degradation to custom allreduce or rccl." + ) + qr_max_size = qr_max_size * MB + self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size) + self.qr_max_size = qr_max_size if qr_max_size is not None \ + else ops.qr_max_size() + self.create_shared_buffer() + self.disabled = False + + def _rocm_arch_available(self): + if not current_platform.is_rocm(): + return False + try: + props = torch.cuda.get_device_properties(0) + gcn_arch = getattr(props, "gcnArchName", "") + supported_archs = ['gfx94', 'gfx95'] + return any(gfx in gcn_arch for gfx in supported_archs) + except Exception as e: + logger.warning("Failed to determine ROCm for quick allreduce: %s", + e) + return False + + def create_shared_buffer(self): + """ + Creates a shared buffer for quickreduce. + Has to be called after init_custom_qr + """ + handle = ops.qr_get_handle(self._ptr) + world_size = dist.get_world_size(group=self.group) + handles = [None] * world_size + dist.all_gather_object(handles, handle, group=self.group) + ops.qr_open_handles(self._ptr, handles) + + def should_quick_allreduce(self, inp: torch.Tensor): + """ + Check if quickreduce is available + """ + if self.disabled: + return False + if inp.dtype not in self._SUPPORTED_DTYPES: + return False + inp_size = inp.numel() * inp.element_size() + # custom quick allreduce requires input byte size to be + # multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + dtype = inp.dtype + if self.use_fp16_kernels: + dtype = torch.float16 + return inp_size <= self.qr_max_size and \ + inp_size >= self._QR_MIN_SIZE[(dtype, self.world_size)]\ + [self.qr_quant_level.value] + + def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): + """Performs an out-of-place custom quick all reduce.""" + # quick allreduce doesn't require a separate graph mode, + # as QR uses static IPC buffer. + if out is None: + out = torch.empty_like(inp) + ops.qr_all_reduce(self._ptr, inp, out, self.qr_quant_level.value, + self.use_fp16_kernels) + return out + + def close(self): + if not self.disabled and getattr(self, "_ptr", None): + if ops is not None: + ops.qr_destroy(self._ptr) + self._ptr = 0 + self.disabled = True + + def __del__(self): + self.close() diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py new file mode 100644 index 0000000..c781004 --- /dev/null +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -0,0 +1,585 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pickle +import time +from contextlib import contextmanager +from dataclasses import dataclass, field +from multiprocessing import shared_memory +from threading import Event +from typing import Any, Optional, Union +from unittest.mock import patch + +import torch +import torch.distributed as dist +import zmq +from torch.distributed import ProcessGroup +from zmq import IPV6 # type: ignore +from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore + +import vllm.envs as envs +from vllm.distributed.utils import StatelessProcessGroup, sched_yield +from vllm.logger import init_logger +from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path, + is_valid_ipv6_address) + +VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL + +logger = init_logger(__name__) + + +class SpinTimer: + + def record_activity(self): + pass + + def spin(self): + sched_yield() + + +class SpinSleepTimer(SpinTimer): + """ + In setups which have long inactivity periods it is desirable to reduce + system power consumption when vllm does nothing. This would lead to more + CPU thermal headroom when a request eventually comes, especially when + multiple GPUs are connected as each GPU would otherwise pin one thread at + 100% CPU usage. + + The simplest solution is to reduce polling frequency when there is no + activity for a certain period of time. + """ + + def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1): + self.last_activity = time.monotonic() + self.busy_loop_s = busy_loop_s + self.wait_sleep_s = wait_sleep_s + + def record_activity(self): + self.last_activity = time.monotonic() + + def spin(self): + curr_time = time.monotonic() + if curr_time >= self.last_activity + self.busy_loop_s: + time.sleep(self.wait_sleep_s) + else: + sched_yield() + + +class ShmRingBuffer: + + def __init__(self, + n_reader: int, + max_chunk_bytes: int, + max_chunks: int, + name: Optional[str] = None): + """ + A shared memory ring buffer implementation for broadcast communication. + Essentially, it is a queue where only one will `enqueue` and multiple + will `dequeue`. The max size of each item, together with the max number + of items that can be stored in the buffer are known in advance. + In this case, we don't need to synchronize the access to + the buffer. + + Buffer memory layout: + data metadata + | | + | (current_idx) | (current_idx) + v v + +-------------------------------+----------------------------------------+ + | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata | + +-------------------------------+----------------------------------------+ + | max_chunks x max_chunk_bytes | max_chunks x (1 + n_reader) bytes | + + metadata memory layout: each byte is a flag, the first byte is the written + flag, and the rest are reader flags. The flags are set to 0 by default. + +--------------+--------------+--------------+-----+--------------+ + | written_flag | reader0_flag | reader1_flag | ... | readerN_flag | + +--------------+--------------+--------------+-----+--------------+ + + The state of metadata is as follows: + + (case 1) 0???...???: the block is not written yet, cannot read, can write + (case 2) 1000...000: the block is just written, can read, cannot write + (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write + (case 4) 1111...111: the block is written and read by all readers, cannot read, can write + + State transition for readers: + + When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read. + Only after the caller finishes reading the block, the reader can mark the block as read. + Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0). + + State transition for writer: + + When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case + to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer + can reset the reader flags to 0, and mark the block as written (from 0 to 1). + NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct. + + During creation, `name` is None and the buffer is created. We can pass the + created object to other processes by pickling it. The other processes will + get the name of the shared memory and open it, so that they can access the + same shared memory buffer. + """# noqa + self.n_reader = n_reader + self.metadata_size = 1 + n_reader + self.max_chunk_bytes = max_chunk_bytes + self.max_chunks = max_chunks + self.total_bytes_of_buffer = (self.max_chunk_bytes + + self.metadata_size) * self.max_chunks + self.data_offset = 0 + self.metadata_offset = self.max_chunk_bytes * self.max_chunks + + if name is None: + # we are creating a buffer + self.is_creator = True + self.shared_memory = shared_memory.SharedMemory( + create=True, size=self.total_bytes_of_buffer) + # initialize the metadata section to 0 + with memoryview(self.shared_memory.buf[self.metadata_offset:] + ) as metadata_buffer: + torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0) + else: + # we are opening an existing buffer + self.is_creator = False + # fix to https://stackoverflow.com/q/62748654/9191338 + # Python incorrectly tracks shared memory even if it is not + # created by the process. The following patch is a workaround. + with patch("multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None): + try: + self.shared_memory = shared_memory.SharedMemory(name=name) + # See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa + # Some platforms allocate memory based on page size, + # so the shared memory block size may be larger or equal + # to the requested size. The size parameter is ignored + # when attaching to an existing block. + assert (self.shared_memory.size + >= self.total_bytes_of_buffer) + except FileNotFoundError: + # we might deserialize the object in a different node + # in this case, this object is not used, + # and we should suppress the error + pass + + def handle(self): + return (self.n_reader, self.max_chunk_bytes, self.max_chunks, + self.shared_memory.name) + + def __reduce__(self): + return ( + self.__class__, + self.handle(), + ) + + def __del__(self): + if hasattr(self, "shared_memory"): + self.shared_memory.close() + if self.is_creator: + self.shared_memory.unlink() + + @contextmanager + def get_data(self, current_idx: int): + start = self.data_offset + current_idx * self.max_chunk_bytes + end = start + self.max_chunk_bytes + with memoryview(self.shared_memory.buf[start:end]) as buf: + yield buf + + @contextmanager + def get_metadata(self, current_idx: int): + start = self.metadata_offset + current_idx * self.metadata_size + end = start + self.metadata_size + with memoryview(self.shared_memory.buf[start:end]) as buf: + yield buf + + +@dataclass +class Handle: + local_reader_ranks: list[int] = field(default_factory=list) + + buffer_handle: Optional[tuple[int, int, int, str]] = None + local_subscribe_addr: Optional[str] = None + remote_subscribe_addr: Optional[str] = None + remote_addr_ipv6: bool = False + + +class MessageQueue: + + def __init__( + self, + n_reader, # number of all readers + n_local_reader, # number of local readers through shared memory + local_reader_ranks: Optional[list[int]] = None, + max_chunk_bytes: int = 1024 * 1024 * 10, + max_chunks: int = 10, + connect_ip: Optional[str] = None, + ): + if local_reader_ranks is None: + local_reader_ranks = list(range(n_local_reader)) + else: + assert len(local_reader_ranks) == n_local_reader + self.n_local_reader = n_local_reader + n_remote_reader = n_reader - n_local_reader + self.n_remote_reader = n_remote_reader + + context = Context() + + if n_local_reader > 0: + # for local readers, we will: + # 1. create a shared memory ring buffer to communicate small data + # 2. create a publish-subscribe socket to communicate large data + self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, + max_chunks) + + # XPUB is very similar to PUB, + # except that it can receive subscription messages + # to confirm the number of subscribers + self.local_socket = context.socket(XPUB) + # set the verbose option so that we can receive every subscription + # message. otherwise, we will only receive the first subscription + # see http://api.zeromq.org/3-3:zmq-setsockopt for more details + self.local_socket.setsockopt(XPUB_VERBOSE, True) + local_subscribe_addr = get_open_zmq_ipc_path() + logger.debug("Binding to %s", local_subscribe_addr) + self.local_socket.bind(local_subscribe_addr) + + self.current_idx = 0 + else: + self.buffer = None # type: ignore + local_subscribe_addr = None + self.local_socket = None + self.current_idx = -1 + + remote_addr_ipv6 = False + if n_remote_reader > 0: + # for remote readers, we will: + # create a publish-subscribe socket to communicate large data + if not connect_ip: + connect_ip = get_ip() + self.remote_socket = context.socket(XPUB) + self.remote_socket.setsockopt(XPUB_VERBOSE, True) + remote_subscribe_port = get_open_port() + if is_valid_ipv6_address(connect_ip): + self.remote_socket.setsockopt(IPV6, 1) + remote_addr_ipv6 = True + connect_ip = f"[{connect_ip}]" + socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}" + self.remote_socket.bind(socket_addr) + remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}" + else: + remote_subscribe_addr = None + self.remote_socket = None + + self._is_writer = True + self._is_local_reader = False + self.local_reader_rank = -1 + # rank does not matter for remote readers + self._is_remote_reader = False + self._read_spin_timer = SpinTimer() + + self.handle = Handle( + local_reader_ranks=local_reader_ranks, + buffer_handle=self.buffer.handle() + if self.buffer is not None else None, + local_subscribe_addr=local_subscribe_addr, + remote_subscribe_addr=remote_subscribe_addr, + remote_addr_ipv6=remote_addr_ipv6, + ) + + logger.info("vLLM message queue communication handle: %s", self.handle) + + def export_handle(self) -> Handle: + return self.handle + + @staticmethod + def create_from_handle(handle: Handle, rank) -> "MessageQueue": + self = MessageQueue.__new__(MessageQueue) + self.handle = handle + self._is_writer = False + + context = Context() + + if rank in handle.local_reader_ranks: + assert handle.buffer_handle is not None + self.buffer = ShmRingBuffer(*handle.buffer_handle) + self.current_idx = 0 + self.local_reader_rank = handle.local_reader_ranks.index(rank) + self._is_local_reader = True + self._is_remote_reader = False + + self.local_socket = context.socket(SUB) + self.local_socket.setsockopt_string(SUBSCRIBE, "") + socket_addr = handle.local_subscribe_addr + logger.debug("Connecting to %s", socket_addr) + self.local_socket.connect(socket_addr) + + self.remote_socket = None + + self._read_spin_timer = SpinSleepTimer( + ) if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer() + else: + self.buffer = None # type: ignore + self.current_idx = -1 + self.local_reader_rank = -1 + self._is_local_reader = False + self._is_remote_reader = True + + self.local_socket = None + + self.remote_socket = context.socket(SUB) + self.remote_socket.setsockopt_string(SUBSCRIBE, "") + if handle.remote_addr_ipv6: + self.remote_socket.setsockopt(IPV6, 1) + socket_addr = handle.remote_subscribe_addr + logger.debug("Connecting to %s", socket_addr) + self.remote_socket.connect(socket_addr) + + return self + + def wait_until_ready(self): + """This is a collective operation. All processes (including the + readers and the writer) should call this function. + """ + if self._is_writer: + # wait for all readers to connect + + # local readers + for i in range(self.n_local_reader): + # wait for subscription messages from all local readers + self.local_socket.recv() + if self.n_local_reader > 0: + # send a message to all local readers + # to make sure the publish channel is working + self.local_socket.send(b"READY") + + # remote readers + for i in range(self.n_remote_reader): + # wait for subscription messages from all remote readers + self.remote_socket.recv() + if self.n_remote_reader > 0: + # send a message to all remote readers + # to make sure the publish channel is working + self.remote_socket.send(b"READY") + elif self._is_local_reader: + # wait for the writer to send a message + recv = self.local_socket.recv() + assert recv == b"READY" + elif self._is_remote_reader: + # wait for the writer to send a message + recv = self.remote_socket.recv() + assert recv == b"READY" + + @contextmanager + def acquire_write(self, timeout: Optional[float] = None): + assert self._is_writer, "Only writers can acquire write" + start_time = time.monotonic() + n_warning = 1 + while True: + with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + read_count = sum(metadata_buffer[1:]) + written_flag = metadata_buffer[0] + if written_flag and read_count != self.buffer.n_reader: + # this block is written and not read by all readers + # for writers, `self.current_idx` is the next block to write + # if this block is not ready to write, + # we need to wait until it is read by all readers + + # Release the processor to other threads + sched_yield() + + # if we wait for a long time, log a message + if (time.monotonic() - start_time + > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning): + logger.debug( + ("No available shared memory broadcast block found" + " in %s second."), + VLLM_RINGBUFFER_WARNING_INTERVAL, + ) + n_warning += 1 + + # if we time out, raise an exception + if (timeout is not None + and time.monotonic() - start_time > timeout): + raise TimeoutError + + continue + # found a block that is either + # (1) not written + # (2) read by all readers + + # mark the block as not written + metadata_buffer[0] = 0 + # let caller write to the buffer + with self.buffer.get_data(self.current_idx) as buf: + yield buf + + # caller has written to the buffer + # NOTE: order is important here + # first set the read flags to 0 + # then set the written flag to 1 + # otherwise, the readers may think they already read the block + for i in range(1, self.buffer.n_reader + 1): + # set read flag to 0, meaning it is not read yet + metadata_buffer[i] = 0 + # mark the block as written + metadata_buffer[0] = 1 + self.current_idx = (self.current_idx + + 1) % self.buffer.max_chunks + break + + @contextmanager + def acquire_read(self, + timeout: Optional[float] = None, + cancel: Optional[Event] = None): + assert self._is_local_reader, "Only readers can acquire read" + start_time = time.monotonic() + n_warning = 1 + while True: + with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + read_flag = metadata_buffer[self.local_reader_rank + 1] + written_flag = metadata_buffer[0] + if not written_flag or read_flag: + # this block is either + # (1) not written + # (2) already read by this reader + + # for readers, `self.current_idx` is the next block to read + # if this block is not ready, + # we need to wait until it is written + + # Release the processor to other threads + self._read_spin_timer.spin() + + # if we wait for a long time, log a message + if (time.monotonic() - start_time + > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning): + logger.debug( + ("No available shared memory broadcast block found" + " in %s second."), + VLLM_RINGBUFFER_WARNING_INTERVAL, + ) + n_warning += 1 + + if cancel is not None and cancel.is_set(): + raise RuntimeError("cancelled") + + # if we time out, raise an exception + if (timeout is not None + and time.monotonic() - start_time > timeout): + raise TimeoutError + + continue + # found a block that is not read by this reader + # let caller read from the buffer + with self.buffer.get_data(self.current_idx) as buf: + yield buf + + # caller has read from the buffer + # set the read flag + metadata_buffer[self.local_reader_rank + 1] = 1 + self.current_idx = (self.current_idx + + 1) % self.buffer.max_chunks + + self._read_spin_timer.record_activity() + break + + def enqueue(self, obj, timeout: Optional[float] = None): + """ Write to message queue with optional timeout (in seconds) """ + assert self._is_writer, "Only writers can enqueue" + serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) + if self.n_local_reader > 0: + if len(serialized_obj) >= self.buffer.max_chunk_bytes: + with self.acquire_write(timeout) as buf: + buf[0] = 1 # overflow + self.local_socket.send(serialized_obj) + else: + with self.acquire_write(timeout) as buf: + buf[0] = 0 # not overflow + buf[1:len(serialized_obj) + 1] = serialized_obj + if self.n_remote_reader > 0: + self.remote_socket.send(serialized_obj) + + def dequeue(self, + timeout: Optional[float] = None, + cancel: Optional[Event] = None): + """ Read from message queue with optional timeout (in seconds) """ + if self._is_local_reader: + with self.acquire_read(timeout, cancel) as buf: + overflow = buf[0] == 1 + if not overflow: + # no need to know the size of serialized object + # pickle format contains the size information internally + # see https://docs.python.org/3/library/pickle.html + obj = pickle.loads(buf[1:]) + if overflow: + obj = MessageQueue.recv(self.local_socket, timeout) + elif self._is_remote_reader: + obj = MessageQueue.recv(self.remote_socket, timeout) + else: + raise RuntimeError("Only readers can dequeue") + return obj + + @staticmethod + def recv(socket: zmq.Socket, timeout: Optional[float]) -> Any: + timeout_ms = None if timeout is None else int(timeout * 1000) + if not socket.poll(timeout=timeout_ms): + raise TimeoutError + recv = socket.recv(copy=False) + return pickle.loads(recv.buffer) + + def broadcast_object(self, obj=None): + if self._is_writer: + self.enqueue(obj) + return obj + else: + return self.dequeue() + + @staticmethod + def create_from_process_group(pg: Union[ProcessGroup, + StatelessProcessGroup], + max_chunk_bytes, + max_chunks, + writer_rank=0) -> "MessageQueue": + if isinstance(pg, ProcessGroup): + group_rank = dist.get_rank(pg) + group_world_size = dist.get_world_size(pg) + global_ranks = dist.get_process_group_ranks(pg) + else: + group_rank = pg.rank + group_world_size = pg.world_size + global_ranks = list(range(pg.world_size)) + + from vllm.distributed.parallel_state import in_the_same_node_as + status = in_the_same_node_as(pg, source_rank=writer_rank) + same_node_ranks = [i for i, s in enumerate(status) if s] + n_reader = group_world_size - 1 + n_local_reader = len(same_node_ranks) - 1 + local_reader_ranks = [i for i in same_node_ranks if i != writer_rank] + buffer_io: MessageQueue + if group_rank == writer_rank: + buffer_io = MessageQueue( + n_reader=n_reader, + n_local_reader=n_local_reader, + local_reader_ranks=local_reader_ranks, + max_chunk_bytes=max_chunk_bytes, + max_chunks=max_chunks, + ) + handle = buffer_io.export_handle() + if isinstance(pg, ProcessGroup): + dist.broadcast_object_list([handle], + src=global_ranks[writer_rank], + group=pg) + else: + pg.broadcast_obj(handle, writer_rank) + else: + if isinstance(pg, ProcessGroup): + recv = [None] + dist.broadcast_object_list(recv, + src=global_ranks[writer_rank], + group=pg) + handle = recv[0] # type: ignore + else: + handle = pg.broadcast_obj(None, writer_rank) + buffer_io = MessageQueue.create_from_handle(handle, group_rank) + buffer_io.wait_until_ready() + return buffer_io diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py new file mode 100644 index 0000000..c60a7a7 --- /dev/null +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from typing import Optional + +import torch +from torch.distributed import ProcessGroup + +from vllm.config import get_current_vllm_config +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from .base_device_communicator import DeviceCommunicatorBase + +USE_RAY = parallel_config = get_current_vllm_config( +).parallel_config.distributed_executor_backend == "ray" + +logger = init_logger(__name__) + +if current_platform.is_tpu(): + import torch_xla + import torch_xla.core.xla_model as xm + import torch_xla.runtime as xr + from torch_xla._internal import pjrt + from torch_xla.distributed.xla_multiprocessing import ( + create_optimized_replica_groups) + + if USE_RAY: + from vllm.executor import ray_utils + + +class TpuCommunicator(DeviceCommunicatorBase): + + def __init__(self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = ""): + super().__init__(cpu_group, device, device_group, unique_name) + + # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node + # must be used together. Therefore, the local rank and world size can + # be simply calculated as follows. + global_rank = self.global_rank + global_world_size = self.global_world_size + + if USE_RAY: + logger.info("TpuCommunicator initialized with RAY") + # Calculate how many TPU nodes are in the current deployment. This + # is the Ray placement group if it is deployed with Ray. Default + # to the number of TPU nodes in the Ray cluster. The number of TPU + # nodes is computed by the total number of TPUs divided by the + # number of TPU accelerators per node, to account for clusters + # with both CPUs and TPUs. + num_nodes = ray_utils.get_num_tpu_nodes() + num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group() + if num_nodes_in_pg > 0: + num_nodes = num_nodes_in_pg + + local_world_size = global_world_size // num_nodes + local_rank = global_rank % local_world_size + else: + logger.info("TpuCommunicator initialized with MP") + # Sanity: Verify we run on a single host + num_hosts = torch_xla.tpu.num_tpu_workers() + assert num_hosts == 1 + + # Get the current number of TPUs (we have locally) + local_world_size = torch_xla.tpu.num_available_chips() + + # Get current rank + local_rank = global_rank % local_world_size + + # Ensure environment variables are set for multihost deployments. + # On GKE, this is needed for libtpu and TPU driver to know which TPU + # chip is actually visible. Otherwise the TPU driver will fail to + # initialize because the number of devices would be different from + # the number of visible worker addresses. + os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank) + os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank) + + pjrt.initialize_multiprocess(local_rank, local_world_size) + xr._init_world_size_ordinal() + self.groups = create_optimized_replica_groups() + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + # TODO: Remove the groups specification after XLA compiler can support + # auto-reordering the ring order for all-reduce. + return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups) + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + assert dim == -1, "TPUs only support dim=-1 for all-gather." + return xm.all_gather(input_, dim=dim) + + +try: + from tpu_commons.distributed.device_communicators import ( + TpuCommunicator as TpuCommonsCommunicator) + TpuCommunicator = TpuCommonsCommunicator # type: ignore +except ImportError: + logger.info("tpu_commons not found, using vLLM's TpuCommunicator") + pass diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py new file mode 100644 index 0000000..216ff85 --- /dev/null +++ b/vllm/distributed/device_communicators/xpu_communicator.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from .base_device_communicator import DeviceCommunicatorBase + + +class XpuCommunicator(DeviceCommunicatorBase): + + def __init__(self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = ""): + super().__init__(cpu_group, device, device_group, unique_name) + + def all_reduce(self, input_) -> torch.Tensor: + dist.all_reduce(input_, group=self.device_group) + return input_ + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # For xpu path, gather doesn't work properly together with ray + # cluster so we use all_gather instead for now. + input_size = input_.size() + # Allocate output tensor. + output_tensor = torch.empty((self.world_size, ) + input_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. + dist.all_gather_into_tensor(output_tensor, + input_, + group=self.device_group) + if self.rank_in_group == dst: + # Reshape + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (self.world_size * + input_size[dim], ) + + input_size[dim + 1:]) + else: + output_tensor = None + return output_tensor diff --git a/vllm/distributed/eplb/__init__.py b/vllm/distributed/eplb/__init__.py new file mode 100644 index 0000000..8051102 --- /dev/null +++ b/vllm/distributed/eplb/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +''' +Expert parallelism load balancer (EPLB). +''' + +from .eplb_state import * +from .rebalance_algo import * diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py new file mode 100644 index 0000000..6b0a126 --- /dev/null +++ b/vllm/distributed/eplb/eplb_state.py @@ -0,0 +1,432 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Expert parallelism load balancer (EPLB) metrics and states. + +# Glossary + +- **Logical Expert**: An expert that is part of the model's logical structure. + It holds a set of weights and is replicated across multiple physical + experts. +- **Redundant Expert**: To achieve load balancing, for some popular logical + experts, we create additional copies of the expert weights. During inference, + each of these copies can be routed to by the same set of tokens. +- **Physical Expert**: An expert that is instantiated on a specific device. + It is a replica of a logical expert and can be rearranged across devices. + I.e., one logical expert may have multiple sets of weights initialized on + different devices, and each of these sets is a physical expert. +- **Local Physical Expert**: A physical expert that is instantiated on the + current device. + +For example: DeepSeek-R1 has 256 logical experts, so each MoE layer +has 256 sets of linear layer weights in the model parameters. If we add 32 +redundant experts, DeepSeek-R1 will have 256 + 32 = 288 physical experts in +total. And when deploying, we'll have 288 sets of linear layer weights for each +MoE layer. If we have 32 EP ranks, then each GPU will hold 288 / 32 = 9 local +physical experts. +""" + +import time +from collections.abc import Sequence +from dataclasses import dataclass + +import torch +from torch.distributed import all_gather, all_reduce + +from vllm.config import ParallelConfig +from vllm.distributed.parallel_state import get_ep_group, get_node_count +from vllm.logger import init_logger +from vllm.model_executor.models.interfaces import MixtureOfExperts + +from .rebalance_algo import rebalance_experts +from .rebalance_execute import rearrange_expert_weights_inplace + +logger = init_logger(__name__) + + +@dataclass +class EplbState: + """EPLB metrics.""" + + physical_to_logical_map: torch.Tensor + """ + Mapping from physical experts to logical experts. + + Shape: (num_moe_layers, num_physical_experts) + + # Example + + For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3 + EP ranks, the mapping could look like this: + + ``` + [[0, 1, 2, 3, 0, 1], + [0, 2, 0, 1, 0, 3]] + ``` + """ + logical_to_physical_map: torch.Tensor + """ + Mapping from logical experts to physical experts. + + This is a sparse matrix, where -1 indicates no mapping. + + Shape: (num_moe_layers, num_logical_experts, num_redundant_experts + 1) + + # Example + + For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3 + EP ranks, the mapping could look like this: + + ``` + [[[0, 4, -1], + [1, 5, -1], + [2, -1, -1], + [3, -1, -1]], + [[0, 2, 4], + [3, -1, -1], + [1, -1, -1], + [5, -1, -1]]] + ``` + """ + logical_replica_count: torch.Tensor + """ + Number of replicas for each logical expert. + This is exactly the non-`-1` count in the `logical_to_physical_map`. + + Shape: (num_moe_layers, num_logical_experts) + + # Example + For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3 + EP ranks, the count could look like this: + + ``` + [[2, 2, 1, 1], + [3, 1, 1, 1]] + """ + + expert_load_pass: torch.Tensor + """ + Expert load during this forward pass. + We use the token count each expert processes as the load. + + Shape: (num_moe_layers, num_local_physical_experts) + """ + expert_load_window: torch.Tensor + """ + A sliding window of expert load. + + Shape: (window_size, num_moe_layers, num_local_physical_experts) + """ + expert_load_window_step: int = 0 + """ + Current step in the sliding window. + + Different from `expert_rearrangement_step`, each EP rank may have its own + `expert_load_window_step`. + """ + expert_load_window_size: int = 0 + """ + Size of the expert load sliding window. + This is a constant and is taken from the config. + """ + + expert_rearrangement_step: int = 0 + """ + Steps after last rearrangement. + Will trigger a rearrangement if it exceeds the threshold. + + NOTE: Keep in mind that all EP ranks need to have the same + `expert_rearrangement_step` value to ensure synchronization. + Otherwise, the rearrangement will hang at collective + communication calls. + """ + expert_rearrangement_step_interval: int = 0 + """ + Interval for expert rearrangement steps. + This is a constant and is taken from the config. + """ + + @staticmethod + def build_initial_global_physical_to_logical_map( + num_routed_experts: int, + num_redundant_experts: int, + ) -> Sequence[int]: + """ + Build an initial expert arrangement using the following structure: + [original routed experts, redundant experts] + + Returns: + physical_to_logical_map (Sequence[int]): A list of integers, + where each integer is the index of the logical expert + that the corresponding physical expert maps to. + """ + global_physical_to_logical_map = list(range(num_routed_experts)) + global_physical_to_logical_map += [ + i % num_routed_experts for i in range(num_redundant_experts) + ] + return global_physical_to_logical_map + + @classmethod + def build( + cls, + model: MixtureOfExperts, + device: torch.device, + parallel_config: ParallelConfig, + ) -> "EplbState": + """ + Build the initial EPLB state. + """ + physical_to_logical_map_list = ( + cls.build_initial_global_physical_to_logical_map( + model.num_routed_experts, + model.num_redundant_experts, + )) + physical_to_logical_map = torch.tensor( + physical_to_logical_map_list, + device=device, + ) + logical_to_physical_map = torch.full( + (model.num_logical_experts, model.num_redundant_experts + 1), + -1, + device=device, + ) + logical_replica_count = torch.zeros( + (model.num_logical_experts, ), + device=device, + dtype=torch.long, + ) + + for i in range(model.num_physical_experts): + logical_idx = physical_to_logical_map[i] + logical_to_physical_map[logical_idx, + logical_replica_count[logical_idx]] = i + logical_replica_count[logical_idx] += 1 + + # Duplicate initial mapping for all layers + physical_to_logical_map = physical_to_logical_map.unsqueeze(0).expand( + model.num_moe_layers, + -1, + ).contiguous() + logical_to_physical_map = logical_to_physical_map.unsqueeze(0).expand( + model.num_moe_layers, + -1, + -1, + ).contiguous() + logical_replica_count = logical_replica_count.unsqueeze(0).expand( + model.num_moe_layers, + -1, + ).contiguous() + + expert_load_pass = torch.zeros( + (model.num_moe_layers, model.num_local_physical_experts), + dtype=torch.int32, + device=device, + ) + expert_load_window_size = parallel_config.eplb_window_size + expert_load_window = torch.zeros( + (expert_load_window_size, model.num_moe_layers, + model.num_local_physical_experts), + dtype=torch.int32, + device=device, + ) + + # Set the initial progress of rearrangement to 3/4 + eplb_step_interval = parallel_config.eplb_step_interval + expert_rearrangement_step = max( + 0, eplb_step_interval - eplb_step_interval // 4) + + model.set_eplb_state( + expert_load_pass, + logical_to_physical_map, + logical_replica_count, + ) + + return cls( + physical_to_logical_map, + logical_to_physical_map, + logical_replica_count, + expert_load_pass, + expert_load_window, + expert_load_window_size=expert_load_window_size, + expert_rearrangement_step=expert_rearrangement_step, + expert_rearrangement_step_interval=eplb_step_interval, + ) + + def step(self, + model: MixtureOfExperts, + is_dummy: bool = False, + is_profile: bool = False, + log_stats: bool = False) -> None: + """ + Step the EPLB state. + + Args: + model (MixtureOfExperts): The MoE model. + is_dummy (bool): If `True`, this is a dummy step and the load + metrics recorded in this forward pass will not count. Defaults + to `False`. + is_profile (bool): If `True`, perform a dummy rearrangement + with maximum communication cost. This is used in `profile_run` + to reserve enough memory for the communication buffer. + log_stats (bool): If `True`, log the expert load metrics. + + # Stats + The metrics are all summed up across layers. + - `avg_tokens`: The average load across ranks. + - `max_tokens`: The maximum load across ranks. + - `balancedness`: The ratio of average load to maximum load. + """ + + if is_profile: + self.rearrange(model, is_profile=True) + return + + if is_dummy: + # Do not record load metrics for dummy steps + self.expert_load_pass.zero_() + + if log_stats: + # `num_tokens`: (num_moe_layers,) + num_tokens = self.expert_load_pass.sum(dim=-1) + + # Collect load metrics from all ranks + ep_group = get_ep_group().device_group + num_tokens_list = [ + torch.empty_like(num_tokens) for _ in range(ep_group.size()) + ] + all_gather(num_tokens_list, num_tokens, group=ep_group) + # Stack to get (num_ranks, num_moe_layers) + num_tokens_per_rank = torch.stack(num_tokens_list).float() + + # Compute balancedness ratio: + # for each layer: + # (mean load across ranks) / (max load across ranks) + avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0) + max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum( + dim=0) + + # Just to make type checker happy + tokens_tensors: list[float] = torch.stack( + [avg_tokens_tensor, max_tokens_tensor]).tolist() + avg_tokens, max_tokens = tokens_tensors + balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0 + + if ep_group.rank() == 0: + logger.info( + "EPLB step: avg_tokens=%.2f, max_tokens=%d, " + "balancedness=%.4f", avg_tokens, max_tokens, balancedness) + + # Update the expert load sliding window + if not is_dummy: + self.expert_load_window[self.expert_load_window_step] = ( + self.expert_load_pass.clone()) + self.expert_load_window_step += 1 + if self.expert_load_window_step >= self.expert_load_window_size: + self.expert_load_window_step = 0 + self.expert_load_pass.zero_() + + # Step the expert rearrangement step + # Note that even if this is a dummy step, we still increment the + # rearrangement step and perform rearrangement to ensure all ranks are + # performing collective communication. + self.expert_rearrangement_step += 1 + if (self.expert_rearrangement_step + >= self.expert_rearrangement_step_interval): + self.expert_rearrangement_step = 0 + self.rearrange(model) + + def rearrange(self, + model: MixtureOfExperts, + is_profile: bool = False) -> None: + """ + Rearrange the experts according to the current load. + """ + + ep_group = get_ep_group().device_group + ep_rank = ep_group.rank() + + time_start = None + is_main_rank = ep_rank == 0 + if is_main_rank: + torch.cuda.synchronize() + time_start = time.perf_counter() + logger.info("Rearranging experts %s...", + "(profile)" if is_profile else "") + + # This mapping is only used here, so we do not store it in the state + physical_expert_start = ep_rank * model.num_local_physical_experts + physical_expert_end = (physical_expert_start + + model.num_local_physical_experts) + # (num_moe_layers, num_local_physical_experts) + local_physical_to_logical_map = self.physical_to_logical_map[ + :, + physical_expert_start:physical_expert_end, + ] + + # Map the local physical expert load to global logical experts + logical_expert_load_window = torch.zeros( + self.expert_load_window_size, + model.num_moe_layers, + model.num_logical_experts, + dtype=self.expert_load_window.dtype, + device=self.expert_load_window.device, + ) + logical_expert_load_window.scatter_add_( + dim=-1, + index=local_physical_to_logical_map.unsqueeze(0).expand_as( + self.expert_load_window).long(), + src=self.expert_load_window, + ) + + # Perform all-reduce to get the expert load across all ranks + global_expert_load_window = logical_expert_load_window.sum(dim=0) + all_reduce(global_expert_load_window, group=ep_group) + + # TODO(bowen): Treat differently for prefill and decode nodes + num_replicas = model.num_physical_experts + num_groups = model.num_expert_groups + num_nodes = get_node_count() + num_gpus = ep_group.size() + + if num_gpus % num_nodes != 0: + logger.warning_once( + f"num_gpus % num_nodes != 0, " + "not using hierarchical rearrangement algorithm.\n" + f"{num_gpus=}, {num_nodes=}") + + # Get new expert mappings + ( + new_physical_to_logical_map, + new_logical_to_physical_map, + new_logical_replica_count, + ) = (rebalance_experts( + global_expert_load_window, + num_replicas, + num_groups, + num_nodes, + num_gpus, + )) + + # Update expert weights + rearrange_expert_weights_inplace( + self.physical_to_logical_map, + new_physical_to_logical_map, + model.expert_weights, + ep_group, + is_profile, + ) + + if not is_profile: + self.physical_to_logical_map.copy_(new_physical_to_logical_map) + self.logical_to_physical_map.copy_(new_logical_to_physical_map) + self.logical_replica_count.copy_(new_logical_replica_count) + + if is_main_rank: + assert time_start is not None + torch.cuda.synchronize() + time_end = time.perf_counter() + logger.info( + "Rearranged experts%sin %.2f seconds.", + " (profile) " if is_profile else " ", + time_end - time_start, + ) diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py new file mode 100644 index 0000000..879b5b9 --- /dev/null +++ b/vllm/distributed/eplb/rebalance_algo.py @@ -0,0 +1,234 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Expert parallelism load balancer (EPLB) for vLLM. + +This module implements the core rearrangement algorithm. + +The rearrangement algorithm is adapted from +[DeepSeek EPLB](https://github.com/deepseek-ai/eplb). + +Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example +on how the EPLB algorithm works. +""" + +import torch + + +def balanced_packing(weight: torch.Tensor, + num_packs: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Pack n weighted objects to m packs, such that each bin contains exactly + n/m objects and the weights of all packs are as balanced as possible. + + Parameters: + weight: [X, n], the weight of each item + num_packs: number of packs + + Returns: + pack_index: [X, n], the pack index of each item + rank_in_pack: [X, n], the rank of the item in the pack + """ + num_layers, num_groups = weight.shape + assert num_groups % num_packs == 0 + groups_per_pack = num_groups // num_packs + + if groups_per_pack == 1: + pack_index = torch.arange(weight.size(-1), + dtype=torch.int64, + device=weight.device).expand(weight.shape) + rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) + return pack_index, rank_in_pack + + indices = weight.float().sort(-1, descending=True).indices.cpu() + pack_index = torch.full_like(weight, + fill_value=-1, + dtype=torch.int64, + device="cpu") + rank_in_pack = torch.full_like(pack_index, fill_value=-1) + for i in range(num_layers): + pack_weights = [0] * num_packs + pack_items = [0] * num_packs + for group in indices[i]: + pack = min( + (i + for i in range(num_packs) if pack_items[i] < groups_per_pack), + key=pack_weights.__getitem__, + ) + assert pack_items[pack] < groups_per_pack + pack_index[i, group] = pack + rank_in_pack[i, group] = pack_items[pack] + pack_weights[pack] += weight[i, group] + pack_items[pack] += 1 + return pack_index, rank_in_pack + + +def replicate_experts( + weight: torch.Tensor, + num_phy: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Replicate `num_log` experts to `num_phy` replicas, such that the maximum + load of all replicas is minimized. + + Parameters: + weight: [X, num_log] + num_phy: total number of experts after replication + + Returns: + phy2log: [X, num_phy], logical expert id of each physical expert + rank: [X, num_phy], the replica rank + logcnt: [X, num_log], number of replicas for each logical expert + """ + n, num_log = weight.shape + num_redundant = num_phy - num_log + assert num_redundant >= 0 + device = weight.device + phy2log = torch.arange(num_phy, dtype=torch.int64, + device=device).repeat(n, 1) + rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) + logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) + arangen = torch.arange(n, dtype=torch.int64, device=device) + for i in range(num_log, num_phy): + redundant_indices = (weight / logcnt).max(dim=-1).indices + phy2log[:, i] = redundant_indices + rank[:, i] = logcnt[arangen, redundant_indices] + logcnt[arangen, redundant_indices] += 1 + return phy2log, rank, logcnt + + +def rebalance_experts_hierarchical( + weight: torch.Tensor, + num_physical_experts: int, + num_groups: int, + num_nodes: int, + num_gpus: int, +): + """ + Parameters: + weight: [num_moe_layers, num_logical_experts] + num_physical_experts: number of physical experts after replication + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network + (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + + Returns: + physical_to_logical_map: [num_moe_layers, num_physical_experts] + logical_to_physical_map: [num_moe_layers, num_logical_experts, X] + logical_count: [num_moe_layers, num_logical_experts] + """ + num_layers, num_logical_experts = weight.shape + assert num_logical_experts % num_groups == 0 + group_size = num_logical_experts // num_groups + assert num_groups % num_nodes == 0 + groups_per_node = num_groups // num_nodes + assert num_gpus % num_nodes == 0 + assert num_physical_experts % num_gpus == 0 + phy_experts_per_gpu = num_physical_experts // num_gpus + + def inverse(perm: torch.Tensor) -> torch.Tensor: + inv = torch.empty_like(perm) + inv.scatter_( + 1, + perm, + torch.arange(perm.size(1), dtype=torch.int64, + device=perm.device).expand(perm.shape), + ) + return inv + + # Step 1: pack groups to nodes + tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) + group_pack_index, group_rank_in_pack = balanced_packing( + tokens_per_group, num_nodes) + log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) * + group_size).unsqueeze(-1) + + torch.arange(group_size, + dtype=torch.int64, + device=group_pack_index.device)).flatten(-2) + mlog2log = inverse(log2mlog) + + # Step 2: construct redundant experts within nodes + # [num_layers * num_nodes, num_logical_experts // num_nodes] + tokens_per_mlog = weight.gather(-1, mlog2log).view( + -1, num_logical_experts // num_nodes) + phy2mlog, phyrank, mlogcnt = replicate_experts( + tokens_per_mlog, num_physical_experts // num_nodes) + + # Step 3: pack physical_experts to GPUs + # [num_layers * num_nodes, num_physical_experts // num_nodes] + tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) + pack_index, rank_in_pack = balanced_packing(tokens_per_phy, + num_gpus // num_nodes) + phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack + pphy2phy = inverse(phy2pphy) + + pphy2mlog = phy2mlog.gather( + -1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes] + pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + torch.arange( + 0, + num_logical_experts, + num_logical_experts // num_nodes, + device=group_pack_index.device, + ).view(1, -1, 1)).flatten(-2) + pphy2log = mlog2log.gather(-1, pphy2mlog) + pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) + logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) + return pphy2log, pphyrank, logcnt + + +def rebalance_experts( + weight: torch.Tensor, + num_replicas: int, + num_groups: int, + num_nodes: int, + num_gpus: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Entry point for expert-parallelism load balancer. + + Parameters: + weight: [layers, num_logical_experts], the load statistics for all + logical experts + num_replicas: number of physical experts, must be a multiple of + `num_gpus` + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network + (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + + Returns: + physical_to_logical_map: [layers, num_replicas], the expert index of + each replica + logical_to_physical_map: [layers, num_logical_experts, X], the replica + indices for each expert + expert_count: [layers, num_logical_experts], number of physical + replicas for each logical expert + """ + num_layers, num_logical_experts = weight.shape + weight = weight.float().cpu() + if num_groups % num_nodes == 0: + # use hierarchical load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, num_groups, num_nodes, num_gpus) + else: + # use global load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, 1, 1, num_gpus) + num_redundant_experts = num_replicas - num_logical_experts + maxlogcnt = num_redundant_experts + 1 + log2phy: torch.Tensor = torch.full( + (num_layers, num_logical_experts, maxlogcnt), + -1, + dtype=torch.int64, + device=logcnt.device, + ) + log2phy.view(num_layers, -1).scatter_( + -1, + phy2log * maxlogcnt + phyrank, + torch.arange(num_replicas, dtype=torch.int64, + device=log2phy.device).expand(num_layers, -1), + ) + return phy2log, log2phy, logcnt + + +__all__ = ["rebalance_experts"] diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py new file mode 100644 index 0000000..2ef8587 --- /dev/null +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -0,0 +1,307 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +The actual execution of the rearrangement. + +This involves the exchange of expert weights between GPUs. +""" + +from collections.abc import Iterable, MutableSequence, Sequence +from functools import partial + +import torch +from torch.distributed import (P2POp, ProcessGroup, all_gather, + batch_isend_irecv, get_global_rank) + + +def idx_local_to_global( + local_idx: int, + local_cnt: int, + ep_rank: int, +) -> int: + """ + Convert a local expert index to a global expert index. + """ + return ep_rank * local_cnt + local_idx + + +def idx_global_to_local( + global_idx: int, + local_cnt: int, + ep_rank: int, +) -> int: + """ + Convert a global expert index to a local expert index. + """ + return global_idx - ep_rank * local_cnt + + +def global_idx_to_rank( + global_idx: int, + local_cnt: int, +) -> int: + """ + Convert a global expert index to a rank index. + """ + return global_idx // local_cnt + + +def get_ep_ranks_with_expert( + idx: int, + num_local_experts: int, + old_indices: Sequence[int], + new_indices: Sequence[int], +) -> tuple[MutableSequence[int], MutableSequence[int]]: + """ + Get the ranks of the experts that need to be exchanged. + + Args: + idx: The index of the expert. + num_local_experts: The number of local experts. + old_indices: The old indices of the experts. + new_indices: The new indices of the experts. + + Returns: + A tuple of two lists: + - The ranks of the experts that need to be sent. + - The ranks of the experts that need to be received. + """ + global2rank = partial( + global_idx_to_rank, + local_cnt=num_local_experts, + ) + + ranks_to_send: list[int] = [] + ranks_to_recv: list[int] = [] + + for i, e in enumerate(old_indices): + if e == idx: + rank = global2rank(i) + if not ranks_to_send or ranks_to_send[-1] != rank: + ranks_to_send.append(rank) + + for i, e in enumerate(new_indices): + if e == idx: + rank = global2rank(i) + if not ranks_to_recv or ranks_to_recv[-1] != rank: + ranks_to_recv.append(rank) + + # Remove those ranks that can get this expert locally. + ranks_to_send_set = set(ranks_to_send) + ranks_to_recv_actual = [ + rank for rank in ranks_to_recv if rank not in ranks_to_send_set + ] + + return ranks_to_send, ranks_to_recv_actual + + +def shuffle_layer( + num_local_experts: int, + ep_rank: int, + old_indices: Sequence[int], + new_indices: Sequence[int], + expert_weights: Iterable[torch.Tensor], + expert_weights_buffer: Sequence[torch.Tensor], + ep_group: ProcessGroup, +) -> None: + """ + Perform expert weights rearrangement of one layer. + """ + local2global = partial( + idx_local_to_global, + local_cnt=num_local_experts, + ep_rank=ep_rank, + ) + + # 0. Do nothing for experts that did not change. + is_unchanged = [ + old_indices[local2global(i)] == new_indices[local2global(i)] + for i in range(num_local_experts) + ] + + # 1. Perform weight copy inside the local rank. + is_received_locally = is_unchanged[:] + for src in range(num_local_experts): + src_global = local2global(src) + for dst in range(num_local_experts): + dst_global = local2global(dst) + if is_received_locally[dst]: + continue + if old_indices[src_global] == new_indices[dst_global]: + is_received_locally[dst] = True + for weight, buffer in zip(expert_weights, + expert_weights_buffer): + buffer[dst].copy_(weight[src]) + + p2p_ops: list[P2POp] = [] + + # 2. Initiate sending of weights. + experts_send_loc: dict[int, int] = {} + for src in range(num_local_experts): + expert = old_indices[local2global(src)] + if expert in experts_send_loc: + continue + experts_send_loc[expert] = src + + # We need to sort here to match send/recv + for expert, src in sorted(experts_send_loc.items()): + ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert( + expert, + num_local_experts, + old_indices, + new_indices, + ) + + # Calculate the ranks to send by this rank + num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) + sender_pos = ranks_to_send.index(ep_rank) + recv_begin = sender_pos * num_dst_per_sender + recv_end = recv_begin + num_dst_per_sender + recv_ranks = ranks_to_recv[recv_begin:recv_end] + + # Tackle remainders + remainder_start = len(ranks_to_send) * num_dst_per_sender + recver_pos = remainder_start + sender_pos + if recver_pos < len(ranks_to_recv): + recv_ranks.append(ranks_to_recv[recver_pos]) + + for dst in recv_ranks: + dst_global = get_global_rank(ep_group, dst) + p2p_ops += [ + P2POp( + torch.distributed.isend, + weight[src], + dst_global, + ) for weight in expert_weights + ] + + # 3. Initiate receiving of weights. + experts_recv_loc: dict[int, int] = {} + for dst in range(num_local_experts): + if is_received_locally[dst]: + continue + expert = new_indices[local2global(dst)] + if expert in experts_recv_loc: + continue + experts_recv_loc[expert] = dst + + # We need to sort here to match send/recv + for expert, dst in sorted(experts_recv_loc.items()): + ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert( + expert, + num_local_experts, + old_indices, + new_indices, + ) + + # Calculate the rank to recv by this rank + num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) + recver_pos = ranks_to_recv.index(ep_rank) + remainder_start = len(ranks_to_send) * num_dst_per_sender + if recver_pos < remainder_start: + src = ranks_to_send[recver_pos // num_dst_per_sender] + else: + src = ranks_to_send[recver_pos - remainder_start] + + src_global = get_global_rank(ep_group, src) + p2p_ops += [ + P2POp( + torch.distributed.irecv, + weight[dst], + src_global, + ) for weight in expert_weights_buffer + ] + + # 4. Execute the P2P operations. The real communication happens here. + if p2p_ops: + reqs = batch_isend_irecv(p2p_ops) + for req in reqs: + req.wait() + + # 5. Copy the weights from the buffer back to the original weights. + for dst in range(num_local_experts): + if is_unchanged[dst]: + continue + if is_received_locally[dst]: + for weight, buffer in zip(expert_weights, expert_weights_buffer): + weight[dst].copy_(buffer[dst]) + else: + expert = new_indices[local2global(dst)] + src = experts_recv_loc[expert] + for weight, buffer in zip(expert_weights, expert_weights_buffer): + weight[dst].copy_(buffer[src]) + + +def rearrange_expert_weights_inplace( + old_global_expert_indices: torch.Tensor, + new_global_expert_indices: torch.Tensor, + expert_weights: Sequence[Iterable[torch.Tensor]], + ep_group: ProcessGroup, + is_profile: bool = False, +) -> None: + """ + Rearranges the expert weights in place according to the new expert indices. + + The value of the indices arguments are logical indices of the experts, + while keys are physical. + + Args: + old_global_expert_indices: Shape (num_moe_layers, num_physical_experts). + new_global_expert_indices: Shape (num_moe_layers, num_physical_experts). + expert_weights: A sequence of shape (num_moe_layers)(weight_count) + of tensors of shape (num_local_physical_experts, hidden_size_i). + For example, a linear layer may have up and down projection, + so weight_count = 2. Each weight's hidden size can be different. + ep_group: The device process group for expert parallelism. + is_profile (bool): If `True`, do not perform any actual weight copy. + This is used during profile run, where we only perform dummy + communications to reserve enough memory for the buffers. + """ + num_moe_layers, num_physical_experts = old_global_expert_indices.shape + assert len(expert_weights) == num_moe_layers + + num_local_physical_experts = next(iter(expert_weights[0])).shape[0] + assert new_global_expert_indices.shape == (num_moe_layers, + num_physical_experts) + + ep_rank = ep_group.rank() + ep_size = ep_group.size() + assert num_physical_experts == ep_size * num_local_physical_experts + + # A buffer to hold the expert weights in one layer during the exchange. + # NOTE: Currently we assume the same weights across different layers + # have the same shape. + expert_weights_buffer = [torch.empty_like(w) for w in expert_weights[0]] + + if is_profile: + # Maximum send size is to send all local experts to all ranks, + # So we use a dummy `all_gather` to reserve enough communication buffer + for weight, buffer in zip(expert_weights[0], expert_weights_buffer): + # A `/dev/null`-like buffer to avoid real memory allocation + dummy_recv_buffer = [buffer for _ in range(ep_size)] + # NOTE(bowen): Needed this barrier to avoid OOM during actual + # execution. I'm not very sure why this is needed + torch.distributed.barrier() + all_gather( + dummy_recv_buffer, + weight, + group=ep_group, + ) + return + + for layer in range(num_moe_layers): + # NOTE(bowen): We need this synchronize to run, but I don't know why. + # If you figure out the reason, please let me know -- thank you! + torch.cuda.synchronize() + shuffle_layer( + num_local_physical_experts, + ep_rank, + old_global_expert_indices[layer].tolist(), + new_global_expert_indices[layer].tolist(), + expert_weights[layer], + expert_weights_buffer, + ep_group, + ) + + +__all__ = ["rearrange_expert_weights_inplace"] diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py new file mode 100644 index 0000000..2d79357 --- /dev/null +++ b/vllm/distributed/kv_events.py @@ -0,0 +1,356 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import queue +import threading +import time +from abc import ABC, abstractmethod +from collections import deque +from dataclasses import asdict +from itertools import count +from queue import Queue +from typing import Any, Callable, Optional, Union + +import msgspec +import zmq + +from vllm.config import KVEventsConfig +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class EventBatch( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False, # type: ignore[call-arg] +): + ts: float + events: list[Any] + data_parallel_rank: Optional[int] = None + + +class KVCacheEvent( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False, # type: ignore[call-arg] + tag=True): + """Base class for all KV cache-related events""" + + +class BlockStored(KVCacheEvent): + block_hashes: list[int] + parent_block_hash: Optional[int] + token_ids: list[int] + block_size: int + lora_id: Optional[int] + + +class BlockRemoved(KVCacheEvent): + block_hashes: list[int] + + +class AllBlocksCleared(KVCacheEvent): + pass + + +class KVEventBatch(EventBatch): + events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]] + + +class EventPublisher(ABC): + """Lightweight publisher for EventBatch batches with data parallelism + support. + + In data parallel setups, each DP rank runs its own EventPublisher instance + to avoid duplicate events and ensure proper event attribution: + + - Each DP rank creates a separate publisher + - Publishers automatically annotate events with their data_parallel_rank + - This allows consumers to distinguish events from different DP ranks + + The publisher is responsible for adding DP metadata since the scheduler + operates independently of DP topology and shouldn't need DP awareness. + """ + + def __init__(self, data_parallel_rank: int = 0) -> None: + self._data_parallel_rank = data_parallel_rank + + @abstractmethod + def publish(self, events: EventBatch) -> None: + """Emit events in order. + + Implementations should guarantee at-least-once delivery and + monotonic ordering (e.g., via sequence numbers). + """ + + @abstractmethod + def shutdown(self) -> None: + """Shutdown the publisher.""" + + +class NullEventPublisher(EventPublisher): + """No-op implementation (default when disabled).""" + + def publish(self, events) -> None: + return + + def shutdown(self) -> None: + return + + +class ZmqEventPublisher(EventPublisher): + """Reliable PUB/ROUTER publisher with an in-memory replay buffer. + + Spawns a separate thread to handle publishing from a queue. + + Parameters + ---------- + endpoint: + PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to + connect. + replay_endpoint: + Optional ROUTER address for replay requests. When given, subscribers can + request missed batches by sending the starting sequence number as an + 8-byte big-endian integer. + buffer_steps: + Number of past batches to keep for replay. + hwm: + ZeroMQ high-water-mark for PUB socket. + max_queue_size: + Maximum number of events to buffer in memory. + topic: + Topic to publish events to. + """ + SHUTDOWN_TIMEOUT: float = 1.0 + END_SEQ = (-1).to_bytes(8, "big", signed=True) + + def __init__( + self, + data_parallel_rank: int, + endpoint: str = "tcp://*:5557", + replay_endpoint: Optional[str] = None, + buffer_steps: int = 10_000, + hwm: int = 100_000, + max_queue_size: int = 100_000, + topic: str = "", + ) -> None: + # Storage + super().__init__(data_parallel_rank) + self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size) + self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps) + + # ZMQ sockets + self._ctx = zmq.Context.instance() + self._pub: Optional[zmq.Socket] = None + self._replay: Optional[zmq.Socket] = None + self._dp_rank = data_parallel_rank + + self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank) + self._replay_endpoint = self.offset_endpoint_port( + replay_endpoint, self._dp_rank) + self._hwm = hwm + self._socket_setup() + + # Payload + self._seq_gen = count() + self._topic_bytes = topic.encode('utf-8') + + # Thread + self._running = True + logger.info("Starting ZMQ publisher thread") + + self._thread = threading.Thread(target=self._publisher_thread, + daemon=True, + name="zmq-publisher") + self._thread.start() + + def publish(self, events: EventBatch) -> None: + if not self._running: + raise RuntimeError("Publisher is closed") + if events.data_parallel_rank is None: + events.data_parallel_rank = self._data_parallel_rank + self._event_queue.put(events) + + def shutdown(self) -> None: + """Stop the publisher thread and clean up resources.""" + self._running = False + self._event_queue.put_nowait(None) + + start = time.time() + pending_items = True + while pending_items and (time.time() - start < self.SHUTDOWN_TIMEOUT): + pending_items = not self._event_queue.empty() + if pending_items: + time.sleep(0.1) + + if pending_items: + logger.warning( + "Warning: Queue still has %s items after %s seconds timeout", + self._event_queue.qsize(), + self.SHUTDOWN_TIMEOUT, + ) + + if self._thread.is_alive(): + self._thread.join(timeout=self.SHUTDOWN_TIMEOUT) + + # Clean up ZMQ resources + try: + if self._pub is not None: + self._pub.close(linger=0) + if self._replay is not None: + self._replay.close(linger=0) + finally: + pass # Do not terminate context; other sockets may use it + + def _socket_setup(self) -> None: + """Initialize sockets + https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety + """ + if self._pub is None: + self._pub = self._ctx.socket(zmq.PUB) + self._pub.set_hwm(self._hwm) + # Heuristic: bind if wildcard / * present, else connect. + # bind stable, connect volatile convention + if (self._endpoint is not None + and ("*" in self._endpoint or "::" in self._endpoint + or self._endpoint.startswith("ipc://") + or self._endpoint.startswith("inproc://"))): + self._pub.bind(self._endpoint) + elif self._endpoint is not None: + self._pub.connect(self._endpoint) + + # Set up replay socket: use ROUTER + # 1) handles multiple REQ clients (identities) + # 2) lets us send back one request → many replies (streamed events) + # 3) works in our non‑blocking poll loop alongside PUB + if self._replay_endpoint is not None: + self._replay = self._ctx.socket(zmq.ROUTER) + self._replay.bind(self._replay_endpoint) + + def _publisher_thread(self) -> None: + """Background thread that processes the event queue.""" + self._pack = msgspec.msgpack.Encoder() + + assert self._pub is not None # narrows type for mypy + + while self._running or self._event_queue.qsize() > 0: + # --- replay (non-critical) --------------------------------- + if self._replay is not None and self._replay.poll(0): + try: + self._service_replay() + except Exception as e: + logger.exception("Error in replay: %s", e) + + # --- main queue (critical) --------------------------------- + try: + event = self._event_queue.get(timeout=0.1) + if event is None: + break # Sentinel received, exit thread + except queue.Empty: + continue + + try: + seq = next(self._seq_gen) + + payload = self._pack.encode(event) + seq_bytes = seq.to_bytes(8, "big") + self._pub.send_multipart( + (self._topic_bytes, seq_bytes, payload)) + + self._buffer.append((seq, payload)) + self._event_queue.task_done() + + except Exception as e: + # Publishing failed; back-off a bit to avoid a tight error loop + logger.exception("Error in publisher thread: %s", e) + time.sleep(0.1) + + def _service_replay(self) -> None: + """If a replay request is waiting, send buffered batches.""" + assert self._replay is not None # narrows type for mypy + + frame = self._replay.recv_multipart() + if len(frame) != 3: + logger.warning("Invalid replay request: %s", frame) + return + client_id, _, start_seq_bytes = frame + start_seq = int.from_bytes(start_seq_bytes, "big") + + for seq, buf in self._buffer: + if seq >= start_seq: + # [identity, empty_delim, seq_bytes, payload] + # (identity, empty_delim) are stripped off by the router + # receiving payload is (seq_bytes, payload) + self._replay.send_multipart( + (client_id, b"", seq.to_bytes(8, "big"), buf)) + # Send end of sequence marker + # receiving payload is (-1, b""") + self._replay.send_multipart((client_id, b"", self.END_SEQ, b"")) + + @staticmethod + def offset_endpoint_port(endpoint: Optional[str], + data_parallel_rank: int) -> Optional[str]: + """Helper function to offset the port in an endpoint by + the data parallel rank. + + Args: + endpoint: The endpoint string + (e.g., "tcp://*:5557" or "inproc://cache") + data_parallel_rank: The data parallel rank to offset by + + Returns: + The endpoint with the port offset by data_parallel_rank + or suffix appended + """ + # Do nothing if input is None or data_parallel_rank is 0 + if not endpoint or data_parallel_rank == 0: + return endpoint + + if "inproc" in endpoint: + return f"{endpoint}_dp{data_parallel_rank}" + if "tcp" in endpoint: + if endpoint and ":" in endpoint: + # Get everything after the last colon (the port) + last_colon_idx = endpoint.rfind(":") + base_addr = endpoint[:last_colon_idx] + base_port = int(endpoint[last_colon_idx + 1:]) + new_port = base_port + data_parallel_rank + return f"{base_addr}:{new_port}" + return endpoint + raise ValueError("Invalid endpoint: must contain 'inproc' or 'tcp'") + + +class EventPublisherFactory: + _registry: dict[str, Callable[..., EventPublisher]] = { + "null": NullEventPublisher, + "zmq": ZmqEventPublisher, + } + + @classmethod + def register_publisher(cls, name: str, + ctor: Callable[..., EventPublisher]) -> None: + if name in cls._registry: + raise KeyError(f"publisher '{name}' already registered") + cls._registry[name] = ctor + + @classmethod + def create(cls, + config: Optional[KVEventsConfig], + data_parallel_rank: int = 0) -> EventPublisher: + """Create publisher from a config mapping.""" + if not config: + return NullEventPublisher() + + config_dict = asdict(config) + + kind = config_dict.pop("publisher", "null") + config_dict.pop("enable_kv_cache_events") + try: + constructor = cls._registry[kind] + except KeyError as exc: + raise ValueError(f"Unknown event publisher '{kind}'") from exc + return constructor(data_parallel_rank=data_parallel_rank, + **config_dict) diff --git a/vllm/distributed/kv_transfer/README.md b/vllm/distributed/kv_transfer/README.md new file mode 100644 index 0000000..349d3df --- /dev/null +++ b/vllm/distributed/kv_transfer/README.md @@ -0,0 +1,29 @@ + +# Distributed KV cache transfer + +This folder implements distributed KV cache transfer across vLLM instances. +Currently the main usecase is for disaggregated prefilling. + +## Abstractions + +The KV cache transfer contains three layer of abstractions: + +- KV pipe: a FIFO pipe for torch.tensor transmission. Key APIs: `send_tensor` and `recv_tensor`. +- KV lookup buffer: a lookup buffer for KV caches. Key: the tokens, value: the KV caches (and/or hidden states). Key APIs: `insert` and `drop_select` (similar to SQL semantics). +- KV connector: a connector that connects the KV pipe and KV lookup buffer to vLLM. Key APIs: `send_kv_caches_and_hidden_states` and `recv_kv_caches_and_hidden_states`. + +Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer. + +NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed +communication service already supports key-value-based lookup (like redis or +RDMA database). + +NOTE: If you want to not only transfer KV caches, but adjust the model execution flow of vLLM as well (for example, allow vLLM to receive KV caches on some tokens and do prefill on the remaining tokens), you can bypass both KV pipe layer and KV lookup buffer layer, and directly implement on KV connector layer. Bear in mind that as vLLM's model input is constantly changing, this implementation will likely be broken when vLLM has new updates. + +## Disaggregated prefilling + +The example usage is in [this file](../../../examples/online_serving/disaggregated_prefill.sh). + +Here is the diagram of how we run disaggregated prefilling. + +![Disaggregated prefill workflow](./disagg_prefill_workflow.jpg) diff --git a/vllm/distributed/kv_transfer/__init__.py b/vllm/distributed/kv_transfer/__init__.py new file mode 100644 index 0000000..fa9b7e4 --- /dev/null +++ b/vllm/distributed/kv_transfer/__init__.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.distributed.kv_transfer.kv_transfer_state import ( + KVConnectorBaseType, ensure_kv_transfer_initialized, get_kv_transfer_group, + has_kv_transfer_group, is_v1_kv_transfer_group) + +__all__ = [ + "get_kv_transfer_group", "has_kv_transfer_group", + "is_v1_kv_transfer_group", "ensure_kv_transfer_initialized", + "KVConnectorBaseType" +] diff --git a/vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg b/vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg new file mode 100644 index 0000000..a25ec5e Binary files /dev/null and b/vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg differ diff --git a/vllm/distributed/kv_transfer/kv_connector/__init__.py b/vllm/distributed/kv_transfer/kv_connector/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py new file mode 100644 index 0000000..181c339 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/base.py @@ -0,0 +1,128 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +KVConnectorBase Class for Distributed KV Cache & Hidden State communication + +The class provides two primary abstract methods: +1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states +2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Union + +import torch + +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + + +class KVConnectorBase(ABC): + """ + Abstract base class for a KV connector. + + The class provides two primary abstract methods: + 1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states + 2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states + """ + + @abstractmethod + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ): + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the buffer and release resources. + + This method is responsible for cleaning up resources related to the + connector when it is no longer needed. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: list[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + """ + Send KV caches and hidden states to the connector. + + This method processes the input tokens, KV caches, and + hidden/intermediate states for a given model and sends the data to the + decode instance. + + Args: + model_executable (torch.nn.Module): The model executable containing + start and end layer information. + model_input (ModelInputForGPUWithSamplingMetadata): The input + metadata from vLLM. + kv_caches (list[torch.Tensor]): List of KV caches (keys and values) + for each layer. + hidden_or_intermediate_states (Union[torch.Tensor, + IntermediateTensors]): + The hidden or intermediate states associated with the tokens. + + Returns: + None + + """ + + raise NotImplementedError + + @abstractmethod + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: list[torch.Tensor] + ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + """ + Receive KV caches and hidden states from the connector. + + This method attempts to retrieve KV caches and hidden states for input + tokens. If all required KV caches and hidden states are received, it + will bypass model input, else it will fall back to normal vLLM model + forwarding. + + Args: + model_executable (torch.nn.Module): + The model executable from vLLM modelrunner. + model_input (ModelInputForGPUWithSamplingMetadata): + The model input from vLLM modelrunner. + kv_caches (list[torch.Tensor]): + List of KV caches for each layer. + + Returns: + - hidden_or_intermediate_states (torch.Tensor or + IntermediateTensors): + Concatenated hidden states if all required data is retrieved, + otherwise `None`. + - bypass_model_exec (bool): + Indicates whether the model execution can be skipped (True) or + needs to be redone (False). + - model_input (ModelInputForGPUWithSamplingMetadata): + Optionally adjusted input metadata for re-execution when + `bypass_model_exec=False`. + + """ + + raise NotImplementedError + + +KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1] diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py new file mode 100644 index 0000000..be9ce72 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -0,0 +1,133 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import importlib +from typing import TYPE_CHECKING, Callable + +import vllm.envs as envs +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType +from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, + KVConnectorRole) +from vllm.logger import init_logger + +from .base import KVConnectorBase + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +class KVConnectorFactory: + _registry: dict[str, Callable[[], type[KVConnectorBaseType]]] = {} + + @classmethod + def register_connector(cls, name: str, module_path: str, + class_name: str) -> None: + """Register a connector with a lazy-loading module and class name.""" + if name in cls._registry: + raise ValueError(f"Connector '{name}' is already registered.") + + def loader() -> type[KVConnectorBaseType]: + module = importlib.import_module(module_path) + return getattr(module, class_name) + + cls._registry[name] = loader + + @classmethod + def create_connector_v0(cls, rank: int, local_rank: int, + config: "VllmConfig") -> KVConnectorBase: + if envs.VLLM_USE_V1: + raise ValueError("Attempting to initialize a V0 Connector, " + f"but found {envs.VLLM_USE_V1=}") + + connector_name = config.kv_transfer_config.kv_connector + if connector_name not in cls._registry: + raise ValueError(f"Unsupported connector type: {connector_name}") + + connector_cls = cls._registry[connector_name]() + assert issubclass(connector_cls, KVConnectorBase) + return connector_cls(rank, local_rank, config) + + @classmethod + def create_connector_v1( + cls, + config: "VllmConfig", + role: KVConnectorRole, + ) -> KVConnectorBase_V1: + if not envs.VLLM_USE_V1: + raise ValueError("Attempting to initialize a V1 Connector, " + f"but found {envs.VLLM_USE_V1=}") + + kv_transfer_config = config.kv_transfer_config + connector_name = kv_transfer_config.kv_connector + if connector_name in cls._registry: + connector_cls = cls._registry[connector_name]() + else: + connector_module_path = kv_transfer_config.kv_connector_module_path + if connector_module_path is None: + raise ValueError( + f"Unsupported connector type: {connector_name}") + connector_module = importlib.import_module(connector_module_path) + connector_cls = getattr(connector_module, connector_name) + assert issubclass(connector_cls, KVConnectorBase_V1) + logger.info("Creating v1 connector with name: %s and engine_id: %s", + connector_name, kv_transfer_config.engine_id) + # NOTE(Kuntai): v1 connector is explicitly separated into two roles. + # Scheduler connector: + # - Co-locate with scheduler process + # - Should only be used inside the Scheduler class + # Worker connector: + # - Co-locate with worker process + # - Should only be used inside the forward context & attention layer + # We build separately to enforce strict separation + return connector_cls(config, role) + + +# Register various connectors here. +# The registration should not be done in each individual file, as we want to +# only load the files corresponding to the current connector. +KVConnectorFactory.register_connector( + "PyNcclConnector", + "vllm.distributed.kv_transfer.kv_connector.simple_connector", + "SimpleConnector") + +KVConnectorFactory.register_connector( + "MooncakeConnector", + "vllm.distributed.kv_transfer.kv_connector.simple_connector", + "SimpleConnector") + +KVConnectorFactory.register_connector( + "LMCacheConnector", + "vllm.distributed.kv_transfer.kv_connector.lmcache_connector", + "LMCacheConnector") + +KVConnectorFactory.register_connector( + "MooncakeStoreConnector", + "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector", + "MooncakeStoreConnector") + +KVConnectorFactory.register_connector( + "SharedStorageConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector", + "SharedStorageConnector") + +KVConnectorFactory.register_connector( + "P2pNcclConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector", + "P2pNcclConnector") + +KVConnectorFactory.register_connector( + "LMCacheConnectorV1", + "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector", + "LMCacheConnectorV1") + +KVConnectorFactory.register_connector( + "NixlConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector", + "NixlConnector") + +KVConnectorFactory.register_connector( + "MultiConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector", + "MultiConnector") diff --git a/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py new file mode 100644 index 0000000..78bf309 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +LMCache KV Cache Connector for Distributed Machine Learning Inference + +The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker +(KV cache producer) and decode vLLM worker (KV cache consumer) using LMCache; +(2) offload and share KV caches. +""" + +from typing import TYPE_CHECKING, Union + +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class LMCacheConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + ): + + self.transfer_config = config.kv_transfer_config + self.vllm_config = config + + from lmcache.experimental.cache_engine import LMCacheEngineBuilder + from lmcache.integration.vllm.utils import ENGINE_NAME + from lmcache.integration.vllm.vllm_adapter import ( + RetrieveStatus, StoreStatus, init_lmcache_engine, + lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store, + lmcache_store_kv) + logger.info("Initializing LMCacheConfig under kv_transfer_config %s", + self.transfer_config) + + # TODO (Jiayi): Find model_config, parallel_config, and cache_config + self.engine = init_lmcache_engine(config.model_config, + config.parallel_config, + config.cache_config) + self.lmcache_engine_name = ENGINE_NAME + self.lmcache_engine_builder = LMCacheEngineBuilder + + self.model_config = config.model_config + self.parallel_config = config.parallel_config + self.cache_config = config.cache_config + self.lmcache_retrieve_kv = lmcache_retrieve_kv + self.lmcache_store_kv = lmcache_store_kv + self.lmcache_should_retrieve = lmcache_should_retrieve + self.lmcache_should_store = lmcache_should_store + self.store_status = StoreStatus + self.retrieve_status = RetrieveStatus + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: list[torch.Tensor] + ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + retrieve_status = self.lmcache_should_retrieve(model_input) + model_input, bypass_model_exec, hidden_or_intermediate_states =\ + self.lmcache_retrieve_kv( + model_executable, model_input, self.cache_config, kv_caches, + retrieve_status) + return hidden_or_intermediate_states, bypass_model_exec, model_input + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: list[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + store_status = self.lmcache_should_store(model_input) + self.lmcache_store_kv( + self.model_config, + self.parallel_config, + self.cache_config, + model_executable, + model_input, + kv_caches, + store_status, + ) + + def close(self): + self.lmcache_engine_builder.destroy(self.lmcache_engine_name) diff --git a/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py b/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py new file mode 100644 index 0000000..94a7ce9 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +MooncakeStore Connector for Distributed Machine Learning Inference +The MooncakeStoreConnector transfers KV caches between prefill vLLM workers +(KV cache producer) and decode vLLM workers (KV cache consumer) using a +database-style KVStore. +""" +import hashlib +from typing import TYPE_CHECKING, Union + +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_connector.utils import ( + model_aware_kv_ops_helper as kv_helper) +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class MooncakeStoreConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + ): + self.kv_transfer_config = config.kv_transfer_config + self.kv_helper = kv_helper(config) + self.local_tp_rank = local_rank + + # Init kv_store + if self.kv_transfer_config.kv_connector == "MooncakeStoreConnector": + # Check if MOONCAKE_CONFIG_PATH is set + import os + use_mooncake_store = os.getenv('MOONCAKE_CONFIG_PATH') is not None + + if not use_mooncake_store: + raise ValueError( + "To use MooncakeStoreConnector, you need to pass the ENV: " + "'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.") + else: + from vllm.distributed.kv_transfer.kv_lookup_buffer.mooncake_store import ( # noqa: E501 + MooncakeStore) + logger.info( + "Initializing KVStoreConnector under kv_transfer_config %s", + self.kv_transfer_config) + self.kv_store = MooncakeStore(config) + else: + logger.error("Can not find %s", + self.kv_transfer_config.kv_connector) + + assert self.kv_store is not None + + def close(self) -> None: + """Close the buffer and release resources. + This method is responsible for cleaning up resources related to the + connector when it is no longer needed. + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + self.kv_store.close() + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: list[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + num_heads, head_size = self.kv_helper.get_model_args(model_executable) + + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + + current_tokens = input_tokens_tensor[start_pos:end_pos] + store_key_prefix = self.tensor_hash(current_tokens) + keys, values = [], [] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] + key_cache, value_cache = self.kv_helper.get_kv_from_cache( + kv_cache, num_heads, head_size) + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(value_cache[current_slot_mapping].unsqueeze(0)) + + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + kvcache_to_sent = torch.stack((keys, values), dim=0) + store_kvcache_key = f"{store_key_prefix}_{self.local_tp_rank}" + self.kv_store.put(store_kvcache_key, kvcache_to_sent) + + hidden_key = f"{store_key_prefix}_hidden_{self.local_tp_rank}" + self.kv_store.put(hidden_key, + hidden_or_intermediate_states[start_pos:end_pos]) + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: list[torch.Tensor] + ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + bypass_model_exec = True + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + hidden_or_intermediate_states_for_one_req = [] + + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + + if start_pos >= num_prefill_tokens: + # This can happen during inflight batching. See: + # vllm/worker/model_runner.py::_prepare_model_input_tensors: + # - input_tokens[:num_prefill_tokens] contains prefill tokens. + # - input_tokens[num_prefill_tokens:] contains decode tokens. + logger.warning("You should set --enable_chunked_prefill=False " + "and --max_num_batched_tokens " + "should be equal to max_seq_len_to_capture") + bypass_model_exec = False + assert start_pos == num_prefill_tokens + break + + current_tokens = input_tokens_tensor[start_pos:end_pos] + + # get roi for current seq + load_key_prefix = self.tensor_hash(current_tokens) + load_kvcache_key = f"{load_key_prefix}_{self.local_tp_rank}" + remote_kv = self.kv_store.get(load_kvcache_key) + hidden_key = f"{load_key_prefix}_hidden_{self.local_tp_rank}" + hidden = self.kv_store.get(hidden_key) + + if remote_kv is None or hidden is None: + # didn't find any match. + bypass_model_exec = False + continue + + num_computed_tokens = current_tokens.shape[0] + + # update the end position based on how many tokens are cached. + end_pos = start_pos + num_computed_tokens + + # call self.kv_store to get kv layer by layer + for layer_id in range(start_layer, end_layer): + layer = model_executable.model.layers[layer_id] + # get kvcache object + kv_cache = kv_caches[layer_id - start_layer] + + # get remote kvcache + remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][ + layer_id] + + self.kv_helper.put_kv_to_cache(model_executable, remote_k, + remote_v, layer, kv_cache, + slot_mapping, start_pos, + end_pos) + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + logger.warning( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = None + + else: + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + return hidden_or_intermediate_states, bypass_model_exec, model_input + + @staticmethod + def tensor_hash(tensor: torch.Tensor) -> int: + """Calculate the hash value of the tensor.""" + tensor_bytes = tensor.clone().detach().cpu().numpy().tobytes() + hash_object = hashlib.blake2b(tensor_bytes) + hash_hex = hash_object.hexdigest() + return int(hash_hex[:16], 16) diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py new file mode 100644 index 0000000..e7c079e --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py @@ -0,0 +1,329 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Simple KV Cache Connector for Distributed Machine Learning Inference + +The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache +producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or +MooncakePipe. + +But the logic can be extended to support other pipe and lookup buffer. +""" +from typing import TYPE_CHECKING, Optional, Union + +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_connector.utils import ( + model_aware_kv_ops_helper as kv_helper) +from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( + SimpleBuffer) +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class SimpleConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + ): + + self.config = config.kv_transfer_config + self.kv_helper = kv_helper(config) + + if self.config.kv_connector == "PyNcclConnector": + from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import ( + PyNcclPipe) + logger.info( + "Initializing PyNcclConfig under kv_transfer_config %s", + self.config) + elif self.config.kv_connector == "MooncakeConnector": + # Check if MOONCAKE_CONFIG_PATH is set + import os + use_mooncake_distributed_pipe = os.getenv( + 'MOONCAKE_CONFIG_PATH') is not None + + if not use_mooncake_distributed_pipe: + raise ValueError( + "To use MooncakeConnector, you need to pass the ENV: " + "'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.") + else: + from vllm.distributed.kv_transfer.kv_pipe.mooncake_pipe import ( # noqa: E501 + MooncakePipe) + logger.info( + "Initializing MooncakeConfig under kv_transfer_config %s", + self.config) + + self.lookup_buffer_size = self.config.kv_buffer_size + + self.producer_buffer: Optional[SimpleBuffer] = None + self.consumer_buffer: Optional[SimpleBuffer] = None + + self.producer_data_pipe: Union[PyNcclPipe, MooncakePipe] + self.consumer_data_pipe: Union[PyNcclPipe, MooncakePipe] + self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe] + self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe] + + # 2 pipes for every rank in the world + port_offset_base = 2 * rank + + # In disaggregated prefill, the prefill vLLM only uses send pipe + # and the decode vLLM only uses recv pipe + if self.config.is_kv_producer: + + if self.config.kv_connector == "PyNcclConnector": + self.producer_data_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base, + ) + self.producer_signal_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base + 1, + device="cpu", + ) + elif self.config.kv_connector == "MooncakeConnector": + self.producer_data_pipe = MooncakePipe( + local_rank=local_rank, + config=self.config, + ) + # We only need to initialize MooncakePipe once + self.producer_signal_pipe = self.producer_data_pipe + + self.producer_buffer = SimpleBuffer(self.producer_signal_pipe, + self.producer_data_pipe, + self.config.kv_buffer_size) + + else: + + # the current vLLM instance is KV consumer, so it needs to connect + # its recv pipe to the send pipe of KV producer + if self.config.kv_connector == "PyNcclConnector": + self.consumer_data_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base, + ) + self.consumer_signal_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base + 1, + device="cpu", + ) + elif self.config.kv_connector == "MooncakeConnector": + self.consumer_data_pipe = MooncakePipe( + local_rank=local_rank, + config=self.config, + ) + self.consumer_signal_pipe = self.consumer_data_pipe + + self.consumer_buffer = SimpleBuffer( + self.consumer_signal_pipe, + self.consumer_data_pipe, + self.config.kv_buffer_size, + ) + + def select(self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]: + + assert self.consumer_buffer is not None, "Please initialize the "\ + "consumer buffer before calling select." + return self.consumer_buffer.drop_select(input_tokens, roi) + + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + + assert self.producer_buffer is not None, "Please initialize the "\ + "producer buffer before calling insert." + + self.producer_buffer.insert(input_tokens, roi, key, value, hidden) + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: list[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + num_heads, head_size = self.kv_helper.get_model_args(model_executable) + + # query_lens contains new KV caches that are added to vLLM. + # so we will send them to decode instance + # FIXME(Kuntai): This assume that all requests are prefill. + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + + if start_pos >= num_prefill_tokens: + # vllm/worker/model_runner.py::_prepare_model_input_tensors: + # - input_tokens[:num_prefill_tokens] contains prefill tokens. + # - input_tokens[num_prefill_tokens:] contains decode tokens. + logger.warning("You have some decode requests while using " + "SimpleConnector. Their KVCache won't be sent.") + break + + current_tokens = input_tokens_tensor[start_pos:end_pos] + + keys, values = [], [] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] + key_cache, value_cache = self.kv_helper.get_kv_from_cache( + kv_cache, num_heads, head_size) + + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(value_cache[current_slot_mapping].unsqueeze(0)) + + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + + self.insert(current_tokens, + torch.ones_like(current_tokens, + dtype=bool), keys, values, + hidden_or_intermediate_states[start_pos:end_pos]) + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: list[torch.Tensor] + ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + # When bypass_model_exec is set to False, it means that at least for one + # request its corresponding KV cache or hidden state is missing. + # In this case we need to do prefilling to recompute missing KV cache + # and hidden states. + bypass_model_exec = True + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + + hidden_or_intermediate_states_for_one_req = [] + + input_tokens_list = [] + num_computed_tokens_list = [] + start_pos_list = [] + + # enumerate different requests + # FIXME(Kuntai): This impl assumes that all requests are prefill. + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + + if start_pos >= num_prefill_tokens: + # This can happen during inflight batching. See: + # vllm/worker/model_runner.py::_prepare_model_input_tensors: + # - input_tokens[:num_prefill_tokens] contains prefill tokens. + # - input_tokens[num_prefill_tokens:] contains decode tokens. + logger.warning("You should set --enable_chunked_prefill=False " + "and --max_num_batched_tokens " + "should be equal to --max_seq_len_to_capture") + bypass_model_exec = False + assert start_pos == num_prefill_tokens + break + + current_tokens = input_tokens_tensor[start_pos:end_pos] + num_tokens = slen + + # collecting data for rebuilding the input + input_tokens_list.append(current_tokens) + start_pos_list.append(start_pos) + + ret = self.select(current_tokens, + torch.ones_like(current_tokens, dtype=bool)) + if ret[0] is None: + # didn't find any match. + bypass_model_exec = False + num_computed_tokens_list.append(0) + continue + + roi: torch.Tensor = ret[1] + keys: torch.Tensor = ret[2] + values: torch.Tensor = ret[3] + hidden: torch.Tensor = ret[4] + + num_computed_tokens = roi.shape[0] + num_computed_tokens_list.append(num_computed_tokens) + + # check if both KV cache and the hidden states are received + # If not, need to redo the forwarding to compute missing states + if not all([(num_computed_tokens == num_tokens), hidden is not None + ]): + bypass_model_exec = False + + # update the end position based on how many tokens are cached. + end_pos = start_pos + num_computed_tokens + + # put received KV caches into paged memory + for cur_layer in range(start_layer, end_layer): + + layer_id = cur_layer - start_layer + kv_cache = kv_caches[layer_id] + layer = model_executable.model.layers[cur_layer] + + # get remote kvcache + remote_k, remote_v = keys[layer_id], values[layer_id] + + self.kv_helper.put_kv_to_cache(model_executable, remote_k, + remote_v, layer, kv_cache, + slot_mapping, start_pos, + end_pos) + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + # Some of the KV cache is not retrieved + # Here we will fall back to normal model forwarding + # But optionally you can adjust model_input so that you only do + # prefilling on those tokens that are missing KV caches. + logger.warning( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = None + + else: + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + return hidden_or_intermediate_states, bypass_model_exec, model_input + + def close(self): + self.producer_data_pipe.close() + self.consumer_data_pipe.close() + if self.config.kv_connector == "PyNcclConnector": + self.producer_signal_pipe.close() + self.consumer_signal_pipe.close() + elif self.config.kv_connector == "MooncakeConnector": + # MooncakePipe reuses data_pipe for signal_pipe, so we only have to + # close the data_pipe. + pass diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py new file mode 100644 index 0000000..5cbc8ca --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +KV cache helper for store. +""" +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class model_aware_kv_ops_helper: + + def __init__(self, config: VllmConfig): + self.is_deepseek_mla = config.model_config.is_deepseek_mla + self.use_mla_opt = not envs.VLLM_MLA_DISABLE + self.tp_size = config.parallel_config.tensor_parallel_size + + def get_model_args(self, model_executable: torch.nn.Module): + + model_config = model_executable.model.config + self.model_executable = model_executable + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + + # Deepseek's MLA (Multi-head Latent Attention) uses two different + # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0. + # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied, + # resulting in a kv_cache shape of [num_blks, blk_size, 1, + # kv_lora_rank + qk_rope_head_dim]. + # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading + # to a kv_cache shape of [2, num_blks, blk_size, + # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim]. + # For more details, see vllm/attention/backends/mla/common.py. + if self.is_deepseek_mla and self.use_mla_opt: + head_size = model_config.kv_lora_rank + \ + model_config.qk_rope_head_dim + num_heads = 1 + elif self.is_deepseek_mla and not self.use_mla_opt: + head_size = model_config.qk_nope_head_dim + \ + model_config.qk_rope_head_dim + else: + head_size = getattr(model_config, "head_dim", None) + if head_size is None: + head_size = int(hidden_size // num_attention_heads) + + return num_heads, head_size + + def get_kv_from_cache(self, kv_cache, num_heads, head_size): + if self.is_deepseek_mla and self.use_mla_opt: + key_cache = kv_cache.reshape(-1, num_heads, head_size) + value_cache = kv_cache.reshape(-1, num_heads, head_size) + else: + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + return key_cache, value_cache + + def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values, + layer, kv_cache, slot_mapping, start_pos, end_pos): + + model_config = model_executable.model.config + + if self.is_deepseek_mla and self.use_mla_opt: + layer.self_attn.attn = layer.self_attn.mla_attn + k_c_normed_k_pe = keys.squeeze(1) + k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank] + k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:] + ops.concat_and_cache_mla( + k_c_normed.to(kv_cache.device), + k_pe.to(kv_cache.device), + kv_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + ) + else: + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + keys.to(key_cache.device), + values.to(value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + + +def get_kv_connector_cache_layout(): + # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is + # used for faster transfer. + vllm_config = get_current_vllm_config() + kv_config = vllm_config.kv_transfer_config + if kv_config is not None and vllm_config.model_config is None: + logger.warning_once("Unable to detect current VLLM config. " \ + "Defaulting to NHD kv cache layout.") + elif kv_config is not None: + use_mla = vllm_config.model_config.use_mla + if not use_mla and kv_config.kv_connector == "NixlConnector": + logger.info_once("NixlConnector detected. Setting KV cache " \ + "layout to HND for better xfer performance.") + return "HND" + return "NHD" diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py new file mode 100644 index 0000000..f00f31d --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorRole) + +__all__ = ["KVConnectorRole", "KVConnectorBase_V1"] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py new file mode 100644 index 0000000..f80b5eb --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -0,0 +1,283 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State +communication in vLLM v1 + +The class provides the following primitives: + Scheduler-side: runs in the scheduler, binds metadata, which + is used by the worker-side to load/save KV cache. + get_num_new_matched_tokens() - get number of new tokens + that exist in the remote KV cache. Might be called multiple + times for a given request and should be side-effect free. + update_state_after_alloc() - update KVConnector state after + temporary buffer alloc by the CacheManager. + request_finished() - called when a request is finished, with + the computed kv cache blocks for the request. + Returns whether KV cache should be freed now or will be + freed asynchronously and optionally returns KV transfer + params. + + Worker-side: runs in each worker, loads/saves KV cache to/from + the Connector based on the metadata. + start_load_kv() - starts loading all KVs (maybe async) + wait_for_layer_load() - blocks until layer i load is done + + save_kv_layer() - starts saving KV for layer i (maybe async) + wait_for_save() - blocks until all saves are done + + get_finished() - called with ids of finished requests, returns + ids of requests that have completed async sending/recving. +""" + +import enum +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Optional + +import torch + +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.config import VllmConfig + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class KVConnectorRole(enum.Enum): + # Connector running in the scheduler process + SCHEDULER = 0 + + # Connector running in the worker process + WORKER = 1 + + +class KVConnectorMetadata: + """ + Abstract Metadata used to communicate between the + Scheduler KVConnector and Worker KVConnector. + """ + pass + + +class KVConnectorBase_V1(ABC): + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + logger.warning( + "Initializing KVConnectorBase_V1. This API is experimental and " + "subject to change in the future as we iterate the design.") + self._connector_metadata = KVConnectorMetadata() + self._vllm_config = vllm_config + self._role = role + + @property + def role(self) -> KVConnectorRole: + return self._role + + # ============================== + # Worker-side methods + # ============================== + + def bind_connector_metadata( + self, connector_metadata: KVConnectorMetadata) -> None: + """Set the connector metadata from the scheduler. + + This function should be called by the model runner every time + before the model execution. The metadata will be used for runtime + KV cache loading and saving. + + Args: + connector_metadata (dict): the connector metadata. + """ + self._connector_metadata = connector_metadata + + def clear_connector_metadata(self) -> None: + """Clear the connector metadata. + + This function should be called by the model runner every time + after the model execution. + """ + self._connector_metadata = KVConnectorMetadata() + + def _get_connector_metadata(self) -> KVConnectorMetadata: + """Get the connector metadata. + + This function should only be called inside the connector. + + Returns: + ConnectorMetadata: the connector metadata. + """ + return self._connector_metadata + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """ + Initialize with the KV caches. Useful for pre-registering the + KV Caches in the KVConnector (e.g. for NIXL). + + Args: kv_caches: + dictionary of layer names, kv cache + """ + return + + @abstractmethod + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """ + Start loading the KV cache from the connector to vLLM's paged + KV buffer. This is called from the forward context before the + forward pass to enable async loading during model execution. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + + """ + pass + + @abstractmethod + def wait_for_layer_load(self, layer_name: str) -> None: + """ + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. This is called from within attention layer to ensure + async copying from start_load_kv is complete. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + pass + + @abstractmethod + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """ + Start saving a layer of KV cache from vLLM's paged buffer + to the connector. This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + pass + + @abstractmethod + def wait_for_save(self): + """ + Block until all the save operations is done. This is called + as the forward context exits to ensure that the async saving + from save_kv_layer is complete before finishing the forward. + + This prevents overwrites of paged KV buffer before saving done. + """ + pass + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens. + + Returns: + ids of requests that have finished asynchronous transfer + (requests that previously returned True from request_finished()), + tuple of (sending/saving ids, recving/loading ids). + The finished saves/sends req ids must belong to a set provided in a + call to this method (this call or a prior one). + """ + return None, None + + # ============================== + # Scheduler-side methods + # ============================== + + @abstractmethod + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + A tuple with the following elements: + - The number of tokens that can be loaded from the + external KV cache beyond what is already computed. + - `True` if external KV cache tokens will be loaded + asynchronously (between scheduler steps). Must be + 'False' if the first element is 0. + """ + pass + + @abstractmethod + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. + + If get_num_new_matched_tokens previously returned True for a + request, this function may be called twice for that same request - + first when blocks are allocated for the connector tokens to be + asynchronously loaded into, and second when any additional blocks + are allocated, after the load/transfer is complete. + + Args: + request (Request): the request object. + blocks (KVCacheBlocks): the blocks allocated for the request. + num_external_tokens (int): the number of tokens that will be + loaded from the external KV cache. + """ + pass + + @abstractmethod + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + """ + Build the connector metadata for this step. + + This function should NOT modify fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + pass + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Called when a request has finished, before its blocks are freed. + + Returns: + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. + """ + return False, None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py new file mode 100644 index 0000000..e838ac2 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING, Any, Optional + +import torch +from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class LMCacheConnectorV1(KVConnectorBase_V1): + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self) + + # ============================== + # Worker-side methods + # ============================== + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """ + Start loading the KV cache from the connector to vLLM's paged + KV buffer. This is called from the forward context before the + forward pass to enable async loading during model execution. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + + """ + self._lmcache_engine.start_load_kv(forward_context, **kwargs) + + def wait_for_layer_load(self, layer_name: str) -> None: + """ + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. This is called from within attention layer to ensure + async copying from start_load_kv is complete. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + self._lmcache_engine.wait_for_layer_load(layer_name) + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """ + Start saving the a layer of KV cache from vLLM's paged buffer + to the connector. This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + self._lmcache_engine.save_kv_layer(layer_name, kv_layer, attn_metadata, + **kwargs) + + def wait_for_save(self): + """ + Block until all the save operations is done. This is called + as the forward context exits to ensure that the async saving + from save_kv_layer is complete before finishing the forward. + + This prevents overwrites of paged KV buffer before saving done. + """ + self._lmcache_engine.wait_for_save() + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens. + + Returns: + ids of requests that have finished asynchronous transfer + (requests that previously returned True from request_finished()), + tuple of (sending/saving ids, recving/loading ids). + The finished saves/sends req ids must belong to a set provided in a + call to this method (this call or a prior one). + """ + return self._lmcache_engine.get_finished(finished_req_ids) + + # ============================== + # Scheduler-side methods + # ============================== + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + return self._lmcache_engine.get_num_new_matched_tokens( + request, num_computed_tokens), False + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. + """ + self._lmcache_engine.update_state_after_alloc(request, + num_external_tokens) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + """ + Build the connector metadata for this step. + + This function should NOT modify fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + return self._lmcache_engine.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Called when a request has finished, before its blocks are freed. + + Returns: + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. + """ + return self._lmcache_engine.request_finished(request, block_ids) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py new file mode 100644 index 0000000..be3c233 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -0,0 +1,201 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import torch + +from vllm.config import KVTransferConfig, VllmConfig +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.logger import init_logger +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +@dataclass +class MultiKVConnectorMetadata(KVConnectorMetadata): + metadata: tuple[KVConnectorMetadata, ...] + extra_async_saves: Optional[dict[str, int]] = None + + +class MultiConnector(KVConnectorBase_V1): + """ + A wrapper for using multiple KVConnectors at the same time. + + The current logic is: + - Load KV from the first connector that advertises available tokens from + get_num_new_matched_tokens(), based on the order in the config. + - Save to all connectors. + """ + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self._connectors: list[KVConnectorBase_V1] = [] + ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "connectors") + assert ktcs is not None + for ktc in ktcs: + temp_config = copy.copy(vllm_config) + temp_config.kv_transfer_config = KVTransferConfig(**ktc) + self._connectors.append( + KVConnectorFactory.create_connector_v1(temp_config, role)) + + # A mapping from request id to the index of the connector chosen to + # load the request from (if any). + self._requests_to_connector: dict[str, int] = {} + + # Keeps track of *additional* remaining async saves (beyond 1) to be + # finished per request. Not needed for async loads since we only allow + # a single connector to load. + # Propagated from scheduler to worker side via the connector metadata. + self._extra_async_saves: dict[str, int] = {} + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + for c in self._connectors: + c.register_kv_caches(kv_caches) + + # We must override the base class method here because we need to bind + # the metadata to each connector in the order of the connectors in the + # MultiKVConnectorMetadata. + def bind_connector_metadata( + self, connector_metadata: KVConnectorMetadata) -> None: + assert isinstance(connector_metadata, MultiKVConnectorMetadata) + if connector_metadata.extra_async_saves: + self._extra_async_saves.update( + connector_metadata.extra_async_saves) + for c, cm in zip(self._connectors, connector_metadata.metadata): + c.bind_connector_metadata(cm) + + def clear_connector_metadata(self) -> None: + for c in self._connectors: + c.clear_connector_metadata() + + # ============================== + # Worker-side methods + # ============================== + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + for c in self._connectors: + c.start_load_kv(forward_context, **kwargs) + + def wait_for_layer_load(self, layer_name: str) -> None: + for c in self._connectors: + c.wait_for_layer_load(layer_name) + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + for c in self._connectors: + c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs) + + def wait_for_save(self): + for c in self._connectors: + c.wait_for_save() + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + finished_sending: set[str] = set() + finished_recving: set[str] = set() + for c in self._connectors: + sending, recving = c.get_finished(finished_req_ids) + if not recving and not sending: + continue + # Aggregate finished recving request ids. + finished_recving.update(recving or ()) + # Aggregate finished sending request ids - only include + # once we've drained the "extra" count (for cases where + # more than one connector is async-saving the same request). + for req_id in sending or (): + extra_pending = self._extra_async_saves.get(req_id) + if extra_pending is None: + finished_sending.add(req_id) + continue + assert extra_pending > 0 + if extra_pending == 1: + del self._extra_async_saves[req_id] + else: + self._extra_async_saves[req_id] = extra_pending - 1 + + return finished_sending or None, finished_recving or None + + # ============================== + # Scheduler-side methods + # ============================== + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + to_return = (0, False) + for i, c in enumerate(self._connectors): + toks, load_async = c.get_num_new_matched_tokens( + request, num_computed_tokens) + # The first connector that has new matched tokens will be assigned + # to this request. + if to_return[0] == 0 and toks > 0: + self._requests_to_connector[request.request_id] = i + to_return = (toks, load_async) + return to_return + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + chosen_connector = self._requests_to_connector.get( + request.request_id, -1) + empty_blocks = blocks.new_empty() + for i, c in enumerate(self._connectors): + if i == chosen_connector: + # Forward call to the chosen connector (if any). + c.update_state_after_alloc(request, blocks, + num_external_tokens) + else: + # Call with empty blocks for other connectors. + c.update_state_after_alloc(request, empty_blocks, 0) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata: + metadata = MultiKVConnectorMetadata(metadata=tuple( + c.build_connector_meta(scheduler_output) + for c in self._connectors)) + if self._extra_async_saves: + metadata.extra_async_saves = self._extra_async_saves + self._extra_async_saves = {} + return metadata + + def request_finished( + self, + request: "Request", + blocks: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + async_saves = 0 + kv_txfer_params = None + for c in self._connectors: + async_save, txfer_params = c.request_finished(request, blocks) + if async_save: + async_saves += 1 + if txfer_params is not None: + if kv_txfer_params is not None: + #TODO we can probably change this to merge the dicts here, + # checking for key clashes. + raise RuntimeError( + "Only one connector can produce KV transfer params") + kv_txfer_params = txfer_params + if async_saves > 1: + self._extra_async_saves[request.request_id] = async_saves - 1 + + # Clean up other state for this request. + self._requests_to_connector.pop(request.request_id, None) + + return async_saves > 0, kv_txfer_params diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py new file mode 100644 index 0000000..56ae1ac --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -0,0 +1,1103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import math +import queue +import threading +import time +import uuid +from collections import defaultdict +from collections.abc import Iterator +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import msgspec +import torch +import zmq + +from vllm import envs +from vllm.attention.selector import backend_name_to_enum, get_attn_backend +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_group) +from vllm.distributed.utils import divide +from vllm.forward_context import ForwardContext +from vllm.logger import init_logger +from vllm.platforms import _Backend +from vllm.utils import make_zmq_path, make_zmq_socket, round_down +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import RequestStatus + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + +Transfer = tuple[int, float] # (xfer_handle, start_time) +EngineId = str +ReqId = str +GET_META_MSG = b"get_meta_msg" + +logger = init_logger(__name__) + +# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used +try: + from nixl._api import nixl_agent as NixlWrapper + logger.info("NIXL is available") +except ImportError: + logger.warning("NIXL is not available") + NixlWrapper = None + + +class NixlAgentMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True): + engine_id: str + agent_metadata: bytes + kv_caches_base_addr: list[int] + num_blocks: int + block_len: int + attn_backend_name: str + + +@dataclass +class ReqMeta: + local_block_ids: list[int] + remote_block_ids: list[int] + remote_host: str + remote_port: int + remote_engine_id: str + tp_size: int + + +class NixlConnectorMetadata(KVConnectorMetadata): + + def __init__(self): + self.requests: dict[ReqId, ReqMeta] = {} + + def add_new_req( + self, + request_id: ReqId, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + ): + self.requests[request_id] = ReqMeta( + local_block_ids=local_block_ids, + remote_block_ids=kv_transfer_params["remote_block_ids"], + remote_engine_id=kv_transfer_params["remote_engine_id"], + remote_host=kv_transfer_params["remote_host"], + remote_port=kv_transfer_params["remote_port"], + # P workers don't need to receive tp_size from proxy here. + tp_size=kv_transfer_params.get("tp_size", 1), + ) + + +class NixlConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + assert vllm_config.kv_transfer_config is not None + assert vllm_config.kv_transfer_config.engine_id is not None + self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler: Optional[NixlConnectorScheduler] = \ + NixlConnectorScheduler(vllm_config, self.engine_id) + self.connector_worker: Optional[NixlConnectorWorker] = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = NixlConnectorWorker( + vllm_config, self.engine_id) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished() + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, NixlConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + """NixlConnector does not do layerwise saving.""" + pass + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """NixlConnector does not save explicitly.""" + pass + + def wait_for_save(self): + """NixlConnector does not save explicitly.""" + pass + + +class NixlConnectorScheduler: + """Implementation of Scheduler side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.engine_id: EngineId = engine_id + self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST + self.side_channel_port = ( + envs.VLLM_NIXL_SIDE_CHANNEL_PORT + + vllm_config.parallel_config.data_parallel_rank * + vllm_config.parallel_config.tensor_parallel_size) + logger.info("Initializing NIXL Scheduler %s", engine_id) + + # Requests that need to start recv. + # New requests are added by update_state_after_alloc in + # the scheduler. Used to make metadata passed to Worker. + self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + """ + For remote prefill, pull all prompt blocks from remote + asynchronously relative to engine execution. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + Returns: + * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if the external KV cache tokens will be loaded + asynchronously (between scheduler steps). + """ + + params = request.kv_transfer_params + logger.debug( + "NIXLConnector get_num_new_matched_tokens: " + "num_computed_tokens=%s, kv_transfer_params=%s", + num_computed_tokens, params) + + if params is not None and params.get("do_remote_prefill"): + # Remote prefill: get all prompt blocks from remote. + assert num_computed_tokens % self.block_size == 0 + rounded_num_prompt_tokens = round_down( + len(request.prompt_token_ids), self.block_size) + count = max(rounded_num_prompt_tokens - num_computed_tokens, 0) + if count > 0: + return count, True + + # No remote prefill for this request. + return 0, False + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + + params = request.kv_transfer_params + logger.debug( + "NIXLConnector update_state_after_alloc: " + "num_external_tokens=%s, kv_transfer_params=%s", + num_external_tokens, params) + + if params is not None and params.get("do_remote_prefill"): + if params.get("remote_block_ids"): + if all(p in params for p in ("remote_engine_id", "remote_host", + "remote_port")): + # If remote_blocks and num_external_tokens = 0, we have + # a full prefix cache hit on the D worker. We need to call + # send_notif in _read_blocks to free the memory on the P. + local_block_ids = (blocks.get_unhashed_block_ids() + if num_external_tokens > 0 else []) + # Get unhashed blocks to pull from remote. + self._reqs_need_recv[request.request_id] = ( + request, local_block_ids) + else: + logger.warning( + "Got invalid KVTransferParams: %s. This " + "request will not utilize KVTransfer", params) + else: + assert num_external_tokens == 0 + # Only trigger 1 KV transfer per request. + params["do_remote_prefill"] = False + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = NixlConnectorMetadata() + + # Loop through scheduled reqs and convert to ReqMeta. + for req_id, (req, block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + ) + + # Clear the list once workers start the transfers + self._reqs_need_recv.clear() + + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Once a request is finished, determine whether request blocks + should be freed now or will be sent asynchronously and freed later. + """ + + params = request.kv_transfer_params + logger.debug( + "NIXLConnector request_finished, request_status=%s, " + "kv_transfer_params=%s", request.status, params) + if not params: + return False, None + + if params.get("do_remote_prefill"): + # If do_remote_prefill is still True when the request is finished, + # update_state_after_alloc must not have been called (the request + # must have been aborted before it was scheduled). + # To avoid stranding the prefill blocks in the prefill instance, + # we must add empty block_ids to _reqs_need_recv so that our + # worker side will notify and free blocks in the prefill instance. + self._reqs_need_recv[request.request_id] = (request, []) + params["do_remote_prefill"] = False + return False, None + + if (not params.get("do_remote_decode") + or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): + return False, None + + # Get computed blocks. + all_full = request.num_computed_tokens % self.block_size == 0 + computed_block_ids = block_ids if all_full else block_ids[:-1] + + # If prompt < block_size, no xfer so free blocks immediately. + delay_free_blocks = len(computed_block_ids) > 0 + + return delay_free_blocks, dict( + do_remote_prefill=True, + do_remote_decode=False, + remote_block_ids=computed_block_ids, + remote_engine_id=self.engine_id, + remote_host=self.side_channel_host, + remote_port=self.side_channel_port, + tp_size=self.vllm_config.parallel_config.tensor_parallel_size) + + +class NixlConnectorWorker: + """Implementation of Worker side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + if NixlWrapper is None: + logger.error("NIXL is not available") + raise RuntimeError("NIXL is not available") + logger.info("Initializing NIXL wrapper") + logger.info("Initializing NIXL worker %s", engine_id) + + # Config. + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + + # Agent. + self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) + # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. + self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) + + # NIXL handshake port. + # NOTE(rob): Within a DP group, each DP rank gets its own + # base port (which is sent in the KVTransferParams). + # Each TP rank listens/queries on the base_port + tp_rank. + self.side_channel_port: int = ( + envs.VLLM_NIXL_SIDE_CHANNEL_PORT + + vllm_config.parallel_config.data_parallel_rank * + vllm_config.parallel_config.tensor_parallel_size) + + # Metadata. + self.engine_id: EngineId = engine_id + self.tp_rank = get_tensor_model_parallel_rank() + self.world_size = get_tensor_model_parallel_world_size() + self.tp_group = get_tp_group() + + # KV Caches and nixl tracking data. + self.kv_caches: dict[str, torch.Tensor] = {} + + # Map of engine_id -> kv_caches_base_addr. For TP case, each local + # rank will still only pull from a single remote TP worker. + self.kv_caches_base_addr: dict[EngineId, list[int]] = {} + + # Number of NIXL regions. Currently one region per cache + # (so 1 per layer for MLA, otherwise 2 per layer) + self.num_regions = 0 + self.num_layers = 0 + + # nixl_prepped_dlist_handle. + self.src_xfer_side_handle: int = 0 + # Map of engine_id -> nixl_prepped_dlist_handle (int)]. + self.dst_xfer_side_handles: dict[EngineId, int] = {} + + # Map of engine_id -> num_blocks. All ranks in the same deployment will + # have the same number of blocks. + self.dst_num_blocks: dict[EngineId, int] = {} + self._registered_descs: list[Any] = [] + + # In progress transfers. + # [req_id -> list[handle]] + self._recving_transfers = defaultdict[ReqId, list[Transfer]](list) + + # Complete transfer tracker. Used by the rank 0 to track finished + # transactions on ranks 1 to N-1. + # [req_id -> count] + self._done_recving_count: defaultdict[ReqId, + int] = defaultdict(lambda: 0) + self._done_sending_count: defaultdict[ReqId, + int] = defaultdict(lambda: 0) + + # Background thread for handling new handshake requests. + self._nixl_handshake_listener_t: Optional[threading.Thread] = None + # Background thread for initializing new NIXL handshakes. + self._handshake_initiation_executor = ThreadPoolExecutor( + # NIXL is not guaranteed to be thread-safe, limit 1 worker. + max_workers=1, + thread_name_prefix="vllm-nixl-handshake-initiator") + self._ready_requests = queue.Queue[tuple[ReqId, ReqMeta]]() + self._handshake_futures: dict[EngineId, Future[dict[int, str]]] = {} + # Protects _handshake_futures and _remote_agents. + self._handshake_lock = threading.RLock() + + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + + # TODO(mgoin): remove this once we have hybrid memory allocator + # Optimization for models with local attention (Llama 4) + # List of block window sizes for each layer for local attention + self.block_window_per_layer: list[Optional[int]] = [] + self.use_mla = self.model_config.use_mla + + backend = get_attn_backend(self.model_config.get_head_size(), + self.model_config.dtype, + self.cache_config.cache_dtype, + self.block_size, + self.model_config.is_attention_free, + use_mla=self.use_mla) + self.backend_name = backend.get_name() + attn_backend = backend_name_to_enum(self.backend_name) + self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1 + logger.debug("Detected attention backend %s", self.backend_name) + + self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} + # With heterogeneous TP, P must wait for all assigned D TP workers to + # finish reading before safely freeing the blocks. + self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) + + def __del__(self): + """Cleanup background threads on destruction.""" + self._handshake_initiation_executor.shutdown(wait=False) + if self._nixl_handshake_listener_t: + self._nixl_handshake_listener_t.join(timeout=0) + + @staticmethod + def _nixl_handshake_listener(metadata: NixlAgentMetadata, + ready_event: threading.Event, base_port: int, + tp_rank: int): + """Background thread for getting new NIXL handshakes.""" + # NOTE(rob): this is a simple implementation. We will move + # to a better approach via HTTP endpoint soon. + + encoder = msgspec.msgpack.Encoder() + encoded_data = encoder.encode(metadata) + size_in_bytes = len(encoded_data) + logger.debug("Size of encoded NixlAgentMetadata: %s bytes", + str(size_in_bytes)) + + # Listen for new requests for metadata. + host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST + path = make_zmq_path("tcp", host, base_port + tp_rank) + logger.debug("Starting listening on path: %s", path) + with zmq_ctx(zmq.ROUTER, path) as sock: + ready_event.set() + while True: + identity, _, msg = sock.recv_multipart() + if msg != GET_META_MSG: + logger.warning( + "Connection listener got unexpected message %s", msg) + sock.send_multipart((identity, b"", encoded_data)) + + def _nixl_handshake(self, host: str, port: int, + remote_tp_size: int) -> dict[int, str]: + """Do a NIXL handshake with a remote instance.""" + + start_time = time.perf_counter() + + # NOTE(rob): we need each rank to have a unique port. This is + # a hack to keep us moving. We will switch when moving to etcd + # or where we have a single ZMQ socket in the scheduler. + + def handshake(path: str, rank: int) -> str: + # Send query for the request. + with zmq_ctx(zmq.REQ, path) as sock: + sock.send(GET_META_MSG) + metadata_bytes = sock.recv() + decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + metadata = decoder.decode(metadata_bytes) + got_metadata_time = time.perf_counter() + + # Register Remote agent. + remote_agent_name = self.add_remote_agent( + metadata, rank, remote_tp_size) + setup_agent_time = time.perf_counter() + + logger.debug("NIXL handshake: get metadata took: %s", + got_metadata_time - start_time) + logger.debug("NIXL handshake: add agent took: %s", + setup_agent_time - got_metadata_time) + return remote_agent_name + + # Handshake only with the remote TP rank that current local rank will + # pull from. With homogeneous TP it happens to be the same rank_i. + tp_ratio = self._tp_size[self.engine_id] // remote_tp_size + p_remote_rank = self.tp_rank // tp_ratio + path = make_zmq_path("tcp", host, port + p_remote_rank) + logger.debug("Querying metadata on path: %s at remote rank %s", path, + p_remote_rank) + # Remote rank -> agent name. + return {p_remote_rank: handshake(path, p_remote_rank)} + + def _background_nixl_handshake(self, req_id: str, + remote_engine_id: EngineId, meta: ReqMeta): + # Do NIXL handshake in background and add to _ready_requests when done. + fut = self._handshake_futures.get(remote_engine_id) + if fut is None: + fut = self._handshake_initiation_executor.submit( + self._nixl_handshake, meta.remote_host, meta.remote_port, + meta.tp_size) + self._handshake_futures[remote_engine_id] = fut + + def done_callback(f: Future[dict[int, str]], eid=remote_engine_id): + with self._handshake_lock: + del self._handshake_futures[eid] + try: + self._remote_agents[eid] = f.result() + except Exception: + logger.exception("Handshake with %s failed", eid) + + fut.add_done_callback(done_callback) + + # TODO: handle failure state of future in the + # callback, we want to fail the request in this case. + def request_ready(_f: Future[Any], entry=(req_id, meta)): + self._ready_requests.put(entry) + + fut.add_done_callback(request_ready) + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """Register the KV Cache data in nixl.""" + + _, first_kv_cache = next(iter(kv_caches.items())) + kv_elem_size = first_kv_cache.element_size() + + # TODO(tms): Find a more robust way to detect and handle MLA + # NOTE (NickLucche) To move blocks efficiently with NIXL, the expected + # KV memory layout is HND, as opposed to the default NHD. Note that it + # will only affects the strides. For MLA instead, we make require no + # such thing and resort to the standard layout. + use_mla = len(first_kv_cache.shape) == 3 + assert use_mla == self.use_mla + + # TODO (NickLucche) not compatible with hybrid allocator. Enforce check + # once it goes live, as a single kv layout is expected for xfers. + if use_mla: + # MLA case. + self.num_blocks = first_kv_cache.shape[0] + block_rank = 2 # [block_size, latent_dim] + block_shape = first_kv_cache.shape[-block_rank:] + block_size, kv_latent_dim = block_shape + self.slot_size_bytes = kv_elem_size * kv_latent_dim + else: + # [2 (k and v), num_blocks, ...] + if self._use_flashinfer: + # FlashInfer swaps 2<->num_blocks dimensions. + self.num_blocks = first_kv_cache.shape[0] + block_rank = 4 # [2, block_size, kv_heads, head_dim] + else: + self.num_blocks = first_kv_cache.shape[1] + block_rank = 3 # [block_size, kv_heads, head_dim] + block_shape = first_kv_cache.shape[-block_rank:] + block_size, n_kv_heads, head_dim = block_shape[-3:] + # head size in bytes. + self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim + assert block_size == self.block_size + # TODO(tms): self.block_len needs to be per-layer for sliding window, + # hybrid attn, etc + # block size in bytes + self.block_len = kv_elem_size * math.prod(block_shape) + logger.info( + "Registering KV_Caches: use_mla: %s, num_blocks: %s, " + "block_shape: %s, per_layer_kv_cache_shape: %s", use_mla, + self.num_blocks, block_shape, first_kv_cache.shape) + self.dst_num_blocks[self.engine_id] = self.num_blocks + self.kv_caches = kv_caches + kv_caches_base_addr = [] + caches_data = [] + + # Note(tms): I modified this from the original region setup code. + # K and V are now in different regions. Advantage is that we can + # elegantly support MLA and any cases where the K and V tensors + # are non-contiguous (it's not locally guaranteed that they will be) + # Disadvantage is that the encoded NixlAgentMetadata is now larger + # (roughly 8KB vs 5KB). + # Conversely for FlashInfer, K and V are transferred in the same tensor + # to better exploit the memory layout (ie num_blocks is the first dim). + for cache_or_caches in kv_caches.values(): + # Normalize to always be a list of caches + cache_list = [cache_or_caches] if use_mla or self._use_flashinfer \ + else cache_or_caches + for cache in cache_list: + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len + caches_data.append( + (base_addr, region_len, cache.device.index, "")) + kv_caches_base_addr.append(base_addr) + self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + self.num_regions = len(caches_data) + self.num_layers = len(self.kv_caches.keys()) + + # TODO(mgoin): remove this once we have hybrid memory allocator + # Optimization for models with local attention (Llama 4) + if self.vllm_config.model_config.hf_config.model_type == "llama4": + from transformers import Llama4TextConfig + assert isinstance(self.vllm_config.model_config.hf_text_config, + Llama4TextConfig) + llama4_config = self.vllm_config.model_config.hf_text_config + no_rope_layers = llama4_config.no_rope_layers + chunk_size = llama4_config.attention_chunk_size + chunk_block_size = math.ceil(chunk_size / self.block_size) + for layer_idx in range(self.num_layers): + # no_rope_layers[layer_idx] == 0 means NoPE (global) + # Any other value means RoPE (local chunked) + is_local_attention = no_rope_layers[layer_idx] != 0 + block_window = chunk_block_size if is_local_attention else None + self.block_window_per_layer.append(block_window) + logger.debug("Llama 4 block window per layer mapping: %s", + self.block_window_per_layer) + assert len(self.block_window_per_layer) == self.num_layers + + descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") + logger.debug("Registering descs: %s", caches_data) + self.nixl_wrapper.register_memory(descs) + logger.debug("Done registering descs") + self._registered_descs.append(descs) + + # Register local/src descr for NIXL xfer. + blocks_data = [] + for base_addr in self.kv_caches_base_addr[self.engine_id]: + # NOTE With heter-TP, more blocks are prepared than what are + # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We + # could create fewer, but then _get_block_descs_ids needs to + # select agent_meta.num_blocks instead of self.num_blocks for + # local descr, and that makes handling regular flow less clean. + for block_id in range(self.num_blocks): + block_offset = block_id * self.block_len + addr = base_addr + block_offset + # (addr, len, device id) + blocks_data.append((addr, self.block_len, self.tp_rank)) + logger.debug("Created %s blocks for src engine %s and rank %s", + len(blocks_data), self.engine_id, self.tp_rank) + + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + # NIXL_INIT_AGENT to be used for preparations of local descs. + self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( + "NIXL_INIT_AGENT", descs) + + # After KV Caches registered, listen for new connections. + metadata = NixlAgentMetadata( + engine_id=self.engine_id, + agent_metadata=self.nixl_wrapper.get_agent_metadata(), + kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], + num_blocks=self.num_blocks, + block_len=self.block_len, + attn_backend_name=self.backend_name) + ready_event = threading.Event() + self._nixl_handshake_listener_t = threading.Thread( + target=self._nixl_handshake_listener, + args=(metadata, ready_event, self.side_channel_port, self.tp_rank), + daemon=True, + name="nixl_handshake_listener") + self._nixl_handshake_listener_t.start() + ready_event.wait() # Wait for listener ZMQ socket to be ready. + + def add_remote_agent(self, + nixl_agent_meta: NixlAgentMetadata, + remote_tp_rank: int = 0, + remote_tp_size: int = 1) -> str: + """ + Add the remote NIXL agent and prepare the descriptors for reading cache + blocks from remote. + + In particular, handle both homogeneous and heterogeneous TP. The former + requires local rank_i to read from remote rank_i. + The latter, assuming D.world_size > P.world_size, requires that two or + more local TP worker share the xfer from a single TP worker. + + Here's an example: + + rank_offset p_remote_tp_rank + (kv split no) + -------------------------------- + 0 0 Worker0 ---- 1st half of KV ----> Worker0 [ KV Cache ] + / + 1 0 Worker1 ---- 2nd half of KV -----/ + + 0 1 Worker2 ---- 1st half of KV ----> Worker1 [ KV Cache ] + / + 1 1 Worker3 ---- 2nd half of KV -----/ + + + Decoder TP workers Prefix TP workers + (world_size=4) (world_size=2) + tp_ratio = 4 // 2 = 2 + + Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, kv_heads, block_size, head_dim] + then D-Worker_j has [2, num_blocksD, kv_heads//tp_ratio, block_size, head_dim]. Mind the "HND" layout format. + Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio + first heads from all the slots of all the blocks. D-Worker1 will do the same, but reading the second split + along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0. + + Note that the above will also hold true for the homogeneous TP case, where tp_ratio evaluates to 1. + + Regarding MLA case, the cache is replicated across TP workers so the rank_offset will just always be 0 + so that the whole cache is shared by "tp_ratio" D TP workers. + """ # noqa: E501 + engine_id = nixl_agent_meta.engine_id + # TODO re-evaluate refreshing for scaling/recovery + if remote_tp_rank in self._remote_agents.get(engine_id, {}): + return self._remote_agents[engine_id][remote_tp_rank] + + if engine_id in self._tp_size: + assert self._tp_size[engine_id] == remote_tp_size + else: + self._tp_size[engine_id] = remote_tp_size + # We may eventually enable this after asserting equality in cache + # layout and close outputs. + assert nixl_agent_meta.attn_backend_name == self.backend_name + + remote_agent_name = self.nixl_wrapper.add_remote_agent( + nixl_agent_meta.agent_metadata) + + # Number of D TP workers reading from a single P TP worker. This is + # 1 when P and D `--tensor-parallel-size` match. + tp_ratio = divide(self._tp_size[self.engine_id], + self._tp_size[engine_id]) + assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" + + # Handle tp_size>num_kv_heads: replicate KV cache. + total_num_kv_heads = self.model_config.get_total_num_kv_heads() + is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1 + + if self.use_mla or is_kv_replicated: + # With MLA the only difference is in the number of blocks. + remote_block_size = nixl_agent_meta.block_len // ( + self.slot_size_bytes) + assert self.block_len == nixl_agent_meta.block_len + else: + remote_block_size = nixl_agent_meta.block_len // ( + self.slot_size_bytes * tp_ratio) + if self._use_flashinfer: + # Account for joint KV in FlashInfer. + remote_block_size //= 2 + + assert nixl_agent_meta.block_len == self.block_len * tp_ratio, ( + "Remote P worker KV layer cache must be of shape [2, N, " + "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." + ) + + assert self.block_size == remote_block_size, ( + "Remote P worker with different block size is not supported " + f"{self.block_size=} {remote_block_size=}") + + # Create dst descs and xfer side handles. TP workers have same #blocks. + if engine_id in self.dst_num_blocks: + assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks + else: + self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks + + blocks_data = [] + # With homogeneous TP, D pulls the whole kv cache from corresponding + # rank. With heterogeneous TP, prepare the descriptors by splitting the + # P KV cache along kv_head dim, of D worker's kv_head size (D>P). + # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. + # Only register the remote's descriptors if current rank pulls from it. + self.kv_caches_base_addr[ + engine_id] = nixl_agent_meta.kv_caches_base_addr + rank_offset = self.tp_rank % tp_ratio * self.block_len \ + if not (self.use_mla or is_kv_replicated) else 0 + # Register all remote blocks, but only the corresponding kv heads. + for base_addr in nixl_agent_meta.kv_caches_base_addr: + for block_id in range(nixl_agent_meta.num_blocks): + block_offset = block_id * nixl_agent_meta.block_len + # For each block, grab the heads chunk belonging to rank_i + # of size remote_nheads // tp_ratio, which correspond to + # self.block_len == remote_block_len//tp_ratio bytes. + addr = base_addr + block_offset + rank_offset + # (addr, len, device id) + blocks_data.append((addr, self.block_len, remote_tp_rank)) + logger.debug( + "Created %s blocks for dst engine %s with remote rank %s and " + "local rank %s", len(blocks_data), engine_id, remote_tp_rank, + self.tp_rank) + + # Register with NIXL. + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.dst_xfer_side_handles[ + engine_id] = self.nixl_wrapper.prep_xfer_dlist( + remote_agent_name, descs) + + return remote_agent_name + + def get_finished(self) -> tuple[set[str], set[str]]: + """ + Get requests that are done sending or recving. + + In TP>1 setup, each rank exchanges KVs with its counterpart + ranks independently. get_finished() runs in a worker creates + the done_sending and done_recving sets that are sent to the + scheduler via ModelRunnerOutput by Rank 0. To ensure trnxs + are done before adding to finished, Ranks 1 to N-1 communicate + to Rank 0 once their transaction is done + Rank 0 returns + finished sets to Scheduler only once all ranks are done. + """ + done_sending = self._get_new_notifs() + done_recving = self._pop_done_transfers(self._recving_transfers) + if len(done_sending) > 0 or len(done_recving) > 0: + logger.debug( + "Rank %s, get_finished: %s requests done sending " + "and %s requests done recving", self.tp_rank, + len(done_sending), len(done_recving)) + + if self.world_size == 1: + return done_sending, done_recving + + # Rank 0: get finished from all other ranks. + if self.tp_rank == 0: + for req_id in done_sending: + self._done_sending_count[req_id] += 1 + for req_id in done_recving: + self._done_recving_count[req_id] += 1 + + # Keep track of how many other ranks have finished. + other_ranks_finished_ids: list[str] = [] + for i in range(1, self.world_size): + other_ranks_finished_ids.extend( + self.tp_group.recv_object(src=i)) + for req_id in other_ranks_finished_ids: + if (req_id in self._done_recving_count + or req_id in self._recving_transfers): + self._done_recving_count[req_id] += 1 + else: + self._done_sending_count[req_id] += 1 + + # Return ids that finished on all ranks to the scheduler. + all_done_recving: set[str] = set() + for req_id in list(self._done_recving_count.keys()): + if self._done_recving_count[req_id] == self.world_size: + del self._done_recving_count[req_id] + all_done_recving.add(req_id) + + all_done_sending: set[str] = set() + for req_id in list(self._done_sending_count.keys()): + if self._done_sending_count[req_id] == self.world_size: + del self._done_sending_count[req_id] + all_done_sending.add(req_id) + + return all_done_sending, all_done_recving + + # Ranks 1 to N-1: send finished ids to Rank 0. + else: + finished_req_ids = list(done_recving.union(done_sending)) + self.tp_group.send_object(finished_req_ids, dst=0) + + # Unused as only Rank 0 results are sent to scheduler. + return done_sending, done_recving + + def _get_new_notifs(self) -> set[str]: + """ + Get req_ids which got a remote xfer message. When multiple consumers + are reading from the same producer (heterogeneous TP scenario), wait + for all consumers to be done pulling. + """ + notified_req_ids: set[str] = set() + for notifs in self.nixl_wrapper.get_new_notifs().values(): + for notif in notifs: + req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1) + self.consumer_notification_counts_by_req[req_id] += 1 + # Wait all consumers (D) to be done reading before freeing. + if self.consumer_notification_counts_by_req[req_id] == int( + tp_ratio): + notified_req_ids.add(req_id) + del self.consumer_notification_counts_by_req[req_id] + return notified_req_ids + + def _pop_done_transfers( + self, transfers: dict[str, list[tuple[int, float]]]) -> set[str]: + """ + Pop completed xfers by checking for DONE state. + Args: + transfers: dict of req_id -> list[running_xfer] + Returns: + set of req_ids that have all done xfers + """ + done_req_ids: set[str] = set() + for req_id, handles in list(transfers.items()): + in_progress = False + for handle, _xfer_stime in handles: + xfer_state = self.nixl_wrapper.check_xfer_state(handle) + if xfer_state == "DONE": + self.nixl_wrapper.release_xfer_handle(handle) + elif xfer_state == "PROC": + in_progress = True + continue + else: + raise RuntimeError("Transfer failed with state %s", + xfer_state) + if not in_progress: + done_req_ids.add(req_id) + del transfers[req_id] + return done_req_ids + + def start_load_kv(self, metadata: NixlConnectorMetadata): + """ + Start loading by triggering non-blocking nixl_xfer. + We check for these trnxs to complete in each step(). + """ + for req_id, meta in metadata.requests.items(): + remote_engine_id = meta.remote_engine_id + logger.debug( + "start_load_kv for request %s from remote engine %s. " + "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, + remote_engine_id, len(meta.local_block_ids), + len(meta.remote_block_ids)) + if remote_engine_id not in self._remote_agents: + # Initiate handshake with remote engine to exchange metadata. + with self._handshake_lock: + if remote_engine_id not in self._remote_agents: + self._background_nixl_handshake( + req_id, remote_engine_id, meta) + continue + + # Handshake already completed, start async read xfer. + self._read_blocks_for_req(req_id, meta) + + # Start transfers for requests whose handshakes have now finished. + while not self._ready_requests.empty(): + self._read_blocks_for_req(*self._ready_requests.get_nowait()) + + def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): + logger.debug( + "Remote agent %s available, calling _read_blocks for req %s", + meta.remote_engine_id, req_id) + self._read_blocks( + request_id=req_id, + dst_engine_id=meta.remote_engine_id, + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + ) + + def _read_blocks(self, local_block_ids: list[int], + remote_block_ids: list[int], dst_engine_id: str, + request_id: str): + # NOTE(rob): having the staging blocks be on the READER side is + # not going to work well (since we will have to call rearrange tensors). + # after we detect the txn is complete (which means we cannot make the + # read trxn async easily). If we want to make "READ" happen cleanly, + # then we will need to have the staging blocks on the remote side. + + # NOTE(rob): according to nvidia the staging blocks are used to + # saturate IB with heterogeneous TP sizes. We should remove the staging + # blocks until we are ready. + + # Number of D TP workers that will read from dst P. Propagate tp_ratio + # on notification so that dst worker can wait before freeing blocks. + tp_ratio = self._tp_size[ + self.engine_id] // self._tp_size[dst_engine_id] + notif_id = f"{request_id}:{tp_ratio}".encode() + + # Full prefix cache hit: do not need to read remote blocks, + # just notify P worker that we have the blocks we need. + num_local_blocks = len(local_block_ids) + if num_local_blocks == 0: + remote_rank = self.tp_rank // tp_ratio + agent_name = self._remote_agents[dst_engine_id][remote_rank] + self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id) + return + + # Partial prefix cache hit: just read uncomputed blocks. + num_remote_blocks = len(remote_block_ids) + assert num_local_blocks <= num_remote_blocks + if num_local_blocks < num_remote_blocks: + remote_block_ids = remote_block_ids[-num_local_blocks:] + + # Get side handles. + local_xfer_side_handle = self.src_xfer_side_handle + remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] + + # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from + # corresponding rank. With heterogeneous TP, fixing D>P, the D tp + # workers will issue xfers to parts of the P worker remote kv caches. + + # Get descs ids. + local_block_descs_ids: list[int] = [] + remote_block_descs_ids: list[int] = [] + if not self.block_window_per_layer: + # Default case: assume global attention + remote_block_descs_ids = self._get_block_descs_ids( + dst_engine_id, remote_block_ids) + local_block_descs_ids = self._get_block_descs_ids( + self.engine_id, local_block_ids) + else: + # TODO(mgoin): remove this once we have hybrid memory allocator + # Optimization for models with local attention (Llama 4) + for layer_idx, block_window in enumerate( + self.block_window_per_layer): + # For each layer: + if block_window is None: + # If not chunked, we just use the + # full block lists (global attention) + layer_local_block_ids = local_block_ids + layer_remote_block_ids = remote_block_ids + else: + # If chunked, get the last block_window blocks + layer_local_block_ids = local_block_ids[-block_window:] + layer_remote_block_ids = remote_block_ids[-block_window:] + + # Get descs ids for the layer. + layer_local_desc_ids = self._get_block_descs_ids( + self.engine_id, layer_local_block_ids, layer_idx) + layer_remote_desc_ids = self._get_block_descs_ids( + dst_engine_id, layer_remote_block_ids, layer_idx) + + local_block_descs_ids.extend(layer_local_desc_ids) + remote_block_descs_ids.extend(layer_remote_desc_ids) + + assert len(local_block_descs_ids) == len(remote_block_descs_ids) + + # Prepare transfer with Nixl. + handle = self.nixl_wrapper.make_prepped_xfer( + "READ", + local_xfer_side_handle, + local_block_descs_ids, + remote_xfer_side_handle, + remote_block_descs_ids, + notif_msg=notif_id, + ) + + # Begin async xfer. + self.nixl_wrapper.transfer(handle) + + # Use handle to check completion in future step(). + # TODO (NickLucche) surface xfer elapsed time + self._recving_transfers[request_id].append( + (handle, time.perf_counter())) + + def _get_block_descs_ids(self, + engine_id: str, + block_ids: list[int], + layer_idx: Optional[int] = None) -> list[int]: + """ + Get the descs ids for a set of block ids. + If layer_idx is provided, we use the region_ids for the given layer. + Otherwise, we use all regions. + """ + if layer_idx is None: + region_ids = range(self.num_regions) + else: + assert layer_idx < self.num_layers + if self.num_layers < self.num_regions: + # If we have more regions than layers, we assume that + # the regions are organized as [K0, V0, K1, V1, ...] + # and we select K_i and V_i + assert 2 * self.num_layers == self.num_regions + region_ids = range(2 * layer_idx, 2 * layer_idx + 2) + else: + # Otherwise, we assume we have MLA and select i-th layer + assert self.num_layers == self.num_regions + region_ids = range(layer_idx, layer_idx + 1) + + num_blocks = self.dst_num_blocks[engine_id] + + # Compute the desc ids for each block. + descs_ids: list[int] = [] + for reg_id in region_ids: + for block_id in block_ids: + descs_ids.append(reg_id * num_blocks + block_id) + return descs_ids + + +@contextlib.contextmanager +def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: + """Context manager for a ZMQ socket""" + + if socket_type not in (zmq.ROUTER, zmq.REQ): + raise ValueError(f"Unexpected socket type: {socket_type}") + + ctx: Optional[zmq.Context] = None + try: + ctx = zmq.Context() # type: ignore[attr-defined] + yield make_zmq_socket(ctx=ctx, + path=addr, + socket_type=socket_type, + bind=socket_type == zmq.ROUTER) + finally: + if ctx is not None: + ctx.destroy(linger=0) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py new file mode 100644 index 0000000..795ba35 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py @@ -0,0 +1,505 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import regex as re +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import ( + P2pNcclEngine) +from vllm.distributed.parallel_state import get_world_group +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import MLACommonMetadata +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +@dataclass +class ReqMeta: + # Request Id + request_id: str + # Request tokens + token_ids: torch.Tensor + # Slot mappings, should have the same length as token_ids + slot_mapping: torch.Tensor + + @staticmethod + def make_meta(request_id: str, token_ids: list[int], block_ids: list[int], + block_size: int) -> "ReqMeta": + valid_num_tokens = len(token_ids) + token_ids_tensor = torch.tensor(token_ids) + block_ids_tensor = torch.tensor(block_ids) + num_blocks = block_ids_tensor.shape[0] + block_offsets = torch.arange(0, block_size) + slot_mapping = block_offsets.reshape((1, block_size)) + \ + block_ids_tensor.reshape((num_blocks, 1)) * block_size + slot_mapping = slot_mapping.flatten()[:valid_num_tokens] + + return ReqMeta( + request_id=request_id, + token_ids=token_ids_tensor, + slot_mapping=slot_mapping, + ) + + +@dataclass +class P2pNcclConnectorMetadata(KVConnectorMetadata): + requests: list[ReqMeta] + + def __init__(self): + self.requests = [] + + def add_request( + self, + request_id: str, + token_ids: list[int], + block_ids: list[int], + block_size: int, + ) -> None: + self.requests.append( + ReqMeta.make_meta(request_id, token_ids, block_ids, block_size)) + + +class P2pNcclConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self._block_size = vllm_config.cache_config.block_size + self._requests_need_load: dict[str, Any] = {} + self.config = vllm_config.kv_transfer_config + self.is_producer = self.config.is_kv_producer + self.chunked_prefill: dict[str, Any] = {} + + self._rank = get_world_group().rank \ + if role == KVConnectorRole.WORKER else 0 + self._local_rank = get_world_group().local_rank \ + if role == KVConnectorRole.WORKER else 0 + + self.p2p_nccl_engine = P2pNcclEngine( + local_rank=self._local_rank, + config=self.config, + hostname="", + port_offset=self._rank, + ) if role == KVConnectorRole.WORKER else None + + # ============================== + # Worker-side methods + # ============================== + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """Start loading the KV cache from the connector buffer to vLLM's + paged KV buffer. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + """ + + # Only consumer/decode loads KV Cache + if self.is_producer: + return + + assert self.p2p_nccl_engine is not None + + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + return + + def inject_kv_into_layer( + dst_kv_cache_layer: torch.Tensor, + src_kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + request_id: str, + ) -> None: + """Inject the KV cache into the layer. + + Args: + dst_kv_cache_layer (torch.Tensor): the destination KV cache + layer. In shape [2, num_pages, page_size, xxx] if not + using MLA, [num_pages, page_size, xxx] otherwise. + src_kv_cache (torch.Tensor): the source KV cache. In shape + [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] + otherwise. + slot_mapping (torch.Tensor): the slot mapping. In shape + [num_tokens]. + request_id (str): request id for log + """ + dst_kv_cache_layer_shape = dst_kv_cache_layer.shape + if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()): + num_pages = dst_kv_cache_layer_shape[0] + page_size = dst_kv_cache_layer_shape[1] + dst_kv_cache_layer = dst_kv_cache_layer.reshape( + num_pages * page_size, -1) + self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache, + 0) + num_token = src_kv_cache.shape[0] + if len(slot_mapping) == num_token: + dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache + else: + dst_kv_cache_layer[slot_mapping[:num_token], + ...] = src_kv_cache + logger.warning( + "🚧src_kv_cache does not match, num_slot:%d, " + "num_token:%d, request_id:%s", len(slot_mapping), + num_token, request_id) + + dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) + else: + num_pages = dst_kv_cache_layer_shape[1] + page_size = dst_kv_cache_layer_shape[2] + dst_kv_cache_layer = dst_kv_cache_layer.reshape( + 2, num_pages * page_size, -1) + self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache, + 1) + num_token = src_kv_cache.shape[1] + if len(slot_mapping) == num_token: + dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache + else: + dst_kv_cache_layer[:, slot_mapping[:num_token], + ...] = src_kv_cache + logger.warning( + "🚧src_kv_cache does not match, num_slot:%d, " + "num_token:%d, request_id:%s", len(slot_mapping), + num_token, request_id) + + dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) + + # Get the metadata + metadata: KVConnectorMetadata = \ + self._get_connector_metadata() + assert isinstance(metadata, P2pNcclConnectorMetadata) + + if metadata is None: + return + + # Load the KV for each request each layer + for request in metadata.requests: + for layer_name in forward_context.no_compile_layers: + layer = forward_context.no_compile_layers[layer_name] + + # Only process layers that have kv_cache + # attribute (attention layers) Skip non-attention + # layers like FusedMoE + kv_cache = getattr(layer, 'kv_cache', None) + if kv_cache is None: + continue + + kv_cache_layer = kv_cache[ \ + forward_context.virtual_engine] + + kv_cache = self.p2p_nccl_engine.recv_tensor( + request.request_id + "#" + layer_name) + + if kv_cache is None: + logger.warning("🚧src_kv_cache is None, %s", + request.request_id) + continue + + inject_kv_into_layer(kv_cache_layer, kv_cache, + request.slot_mapping, request.request_id) + tensor_id = request.request_id + "#" + layer_name + if tensor_id in self.p2p_nccl_engine.recv_store: + tensor = self.p2p_nccl_engine.recv_store.pop(tensor_id, None) + self.p2p_nccl_engine.send_request_id_to_tensor_ids.pop( + request.request_id, None) + self.p2p_nccl_engine.recv_request_id_to_tensor_ids.pop( + request.request_id, None) + addr = 0 + if isinstance(tensor, tuple): + addr, _, _ = tensor + self.p2p_nccl_engine.pool.free(addr) + + + def wait_for_layer_load(self, layer_name: str) -> None: + """Blocking until the KV for a specific layer is loaded into vLLM's + paged buffer. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + return + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """Start saving the KV cache of the layer from vLLM's paged buffer + to the connector. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + + # Only producer/prefill saves KV Cache + if not self.is_producer: + return + + assert self.p2p_nccl_engine is not None + + def extract_kv_from_layer( + layer: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> torch.Tensor: + """Extract the KV cache from the layer. + + Assume the shape of the layer is (2, num_pages, page_size, xxx) + if MLA is not used, and (num_pages, page_size, xxx) otherwise. + """ + if isinstance(attn_metadata, MLACommonMetadata): + num_pages, page_size = layer.shape[0], layer.shape[1] + return layer.reshape(num_pages * page_size, -1)[slot_mapping, + ...] + num_pages, page_size = layer.shape[1], layer.shape[2] + return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, + ...] + + connector_metadata = self._get_connector_metadata() + assert isinstance(connector_metadata, P2pNcclConnectorMetadata) + for request in connector_metadata.requests: + request_id = request.request_id + ip, port = self.parse_request_id(request_id, True) + remote_address = ip + ":" + str(port + self._rank) + kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping) + self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, + kv_cache, remote_address) + + def wait_for_save(self): + if self.is_producer: + assert self.p2p_nccl_engine is not None + self.p2p_nccl_engine.wait_for_sent() + + def get_finished( + self, finished_req_ids: set[str], + **kwargs) -> tuple[Optional[set[str]], Optional[set[str]]]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens. + + Returns: + ids of requests that have finished asynchronous transfer, + tuple of (sending/saving ids, recving/loading ids). + The finished saves/sends req ids must belong to a set provided in a + call to this method (this call or a prior one). + """ + + assert self.p2p_nccl_engine is not None + + forward_context: ForwardContext = get_forward_context() + return self.p2p_nccl_engine.get_finished(finished_req_ids, + forward_context) + + # ============================== + # Scheduler-side methods + # ============================== + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + if self.is_producer: + return 0, False + + num_external_tokens = (len(request.prompt_token_ids) - 1 - + num_computed_tokens) + + if num_external_tokens < 0: + num_external_tokens = 0 + + return num_external_tokens, False + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. + """ + if not self.is_producer and num_external_tokens > 0: + self._requests_need_load[request.request_id] = ( + request, blocks.get_block_ids()[0]) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + """Build the connector metadata for this step. + + This function should NOT modify any fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + + meta = P2pNcclConnectorMetadata() + + for new_req in scheduler_output.scheduled_new_reqs: + if self.is_producer: + num_scheduled_tokens = ( + scheduler_output.num_scheduled_tokens)[new_req.req_id] + num_tokens = num_scheduled_tokens + new_req.num_computed_tokens + # the request's prompt is chunked prefill + if num_tokens < len(new_req.prompt_token_ids): + # 'CachedRequestData' has no attribute 'prompt_token_ids' + self.chunked_prefill[new_req.req_id] = ( + new_req.block_ids[0], new_req.prompt_token_ids) + continue + # the request's prompt is not chunked prefill + meta.add_request(request_id=new_req.req_id, + token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size) + continue + if new_req.req_id in self._requests_need_load: + meta.add_request(request_id=new_req.req_id, + token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size) + self._requests_need_load.pop(new_req.req_id) + + cached_reqs = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(cached_reqs.req_ids): + num_computed_tokens = cached_reqs.num_computed_tokens[i] + new_block_ids = cached_reqs.new_block_ids[i] + resumed_from_preemption = cached_reqs.resumed_from_preemption[i] + + if self.is_producer: + num_scheduled_tokens = ( + scheduler_output.num_scheduled_tokens)[req_id] + num_tokens = (num_scheduled_tokens + num_computed_tokens) + assert req_id in self.chunked_prefill + block_ids = new_block_ids[0] + if not resumed_from_preemption: + block_ids = (self.chunked_prefill[req_id][0] + block_ids) + prompt_token_ids = self.chunked_prefill[req_id][1] + # the request's prompt is chunked prefill again + if num_tokens < len(prompt_token_ids): + self.chunked_prefill[req_id] = (block_ids, + prompt_token_ids) + continue + # the request's prompt is all prefilled finally + meta.add_request(request_id=req_id, + token_ids=prompt_token_ids, + block_ids=block_ids, + block_size=self._block_size) + self.chunked_prefill.pop(req_id, None) + continue + + # NOTE(rob): here we rely on the resumed requests being + # the first N requests in the list scheduled_cache_reqs. + if not resumed_from_preemption: + break + if req_id in self._requests_need_load: + request, _ = self._requests_need_load.pop(req_id) + total_tokens = num_computed_tokens + 1 + token_ids = request.all_token_ids[:total_tokens] + + # NOTE(rob): For resumed req, new_block_ids is all + # of the block_ids for the request. + block_ids = new_block_ids[0] + + meta.add_request(request_id=req_id, + token_ids=token_ids, + block_ids=block_ids, + block_size=self._block_size) + + # Requests loaded asynchronously are not in the scheduler_output. + # for request_id in self._requests_need_load: + # request, block_ids = self._requests_need_load[request_id] + # meta.add_request(request_id=request.request_id, + # token_ids=request.prompt_token_ids, + # block_ids=block_ids, + # block_size=self._block_size) + + self._requests_need_load.clear() + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Called when a request has finished, before its blocks are freed. + + Returns: + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. + """ + + self.chunked_prefill.pop(request.request_id, None) + + return False, None + + # ============================== + # Static methods + # ============================== + + @staticmethod + def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]: + # Regular expression to match the string hostname and integer port + if is_prefill: + pattern = r"___decode_addr_(.*):(\d+)" + else: + pattern = r"___prefill_addr_(.*):(\d+)___" + + # Use re.search to find the pattern in the request_id + match = re.search(pattern, request_id) + if match: + # Extract the ranks + ip = match.group(1) + port = int(match.group(2)) + + return ip, port + raise ValueError( + f"Request id {request_id} does not contain hostname and port") + + @staticmethod + def check_tensors_except_dim(tensor1, tensor2, dim): + shape1 = tensor1.size() + shape2 = tensor2.size() + + if len(shape1) != len(shape2) or not all( + s1 == s2 + for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim): + raise NotImplementedError( + "Currently, only symmetric TP is supported. Asymmetric TP, PP," + "and others will be supported in future PRs.") diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py new file mode 100644 index 0000000..c6f8aed --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py @@ -0,0 +1,529 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import logging +import os +import threading +import time +import typing +from collections import deque +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Optional + +import msgpack +import torch +import zmq + +from vllm.config import KVTransferConfig +from vllm.distributed.device_communicators.pynccl_wrapper import ( + NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum) +from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501 + TensorMemoryPool) +from vllm.utils import current_stream, get_ip + +if TYPE_CHECKING: + from vllm.forward_context import ForwardContext + +logger = logging.getLogger(__name__) + +DEFAULT_MEM_POOL_SIZE_GB = 32 + + +@contextmanager +def set_p2p_nccl_context(num_channels: str): + original_values: dict[str, Any] = {} + env_vars = [ + 'NCCL_MAX_NCHANNELS', + 'NCCL_MIN_NCHANNELS', + 'NCCL_CUMEM_ENABLE', + 'NCCL_BUFFSIZE', + 'NCCL_PROTO', # LL,LL128,SIMPLE + 'NCCL_ALGO', # RING,TREE + ] + + for var in env_vars: + original_values[var] = os.environ.get(var) + + logger.info("set_p2p_nccl_context, original_values: %s", original_values) + + try: + os.environ['NCCL_MAX_NCHANNELS'] = num_channels + os.environ['NCCL_MIN_NCHANNELS'] = num_channels + os.environ['NCCL_CUMEM_ENABLE'] = '1' + yield + finally: + for var in env_vars: + if original_values[var] is not None: + os.environ[var] = original_values[var] + else: + os.environ.pop(var, None) + + +class P2pNcclEngine: + + def __init__(self, + local_rank: int, + config: KVTransferConfig, + hostname: str = "", + port_offset: int = 0, + library_path: Optional[str] = None) -> None: + self.config = config + self.rank = port_offset + self.local_rank = local_rank + self.device = torch.device(f"cuda:{self.local_rank}") + self.nccl = NCCLLibrary(library_path) + + if not hostname: + hostname = get_ip() + port = int(self.config.kv_port) + port_offset + if port == 0: + raise ValueError("Port cannot be 0") + self._hostname = hostname + self._port = port + + # Each card corresponds to a ZMQ address. + self.zmq_address = f"{self._hostname}:{self._port}" + + # The `http_port` must be consistent with the port of OpenAI. + self.http_address = ( + f"{self._hostname}:" + f"{self.config.kv_connector_extra_config['http_port']}") + + # If `proxy_ip` or `proxy_port` is `""`, + # then the ping thread will not be enabled. + proxy_ip = self.config.get_from_extra_config("proxy_ip", "") + proxy_port = self.config.get_from_extra_config("proxy_port", "") + if proxy_ip == "" or proxy_port == "": + self.proxy_address = "" + else: + self.proxy_address = proxy_ip + ":" + proxy_port + + self.context = zmq.Context() + self.router_socket = self.context.socket(zmq.ROUTER) + self.router_socket.bind(f"tcp://{self.zmq_address}") + + self.poller = zmq.Poller() + self.poller.register(self.router_socket, zmq.POLLIN) + + self.send_store_cv = threading.Condition() + self.send_queue_cv = threading.Condition() + self.recv_store_cv = threading.Condition() + + self.send_stream = torch.cuda.Stream() + self.recv_stream = torch.cuda.Stream() + + mem_pool_size_gb = self.config.get_from_extra_config( + "mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB) + self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb) * + 1024**3) # GB + + # The sending type includes tree mutually exclusive options: + # PUT, GET, PUT_ASYNC. + self.send_type = self.config.get_from_extra_config("send_type", "PUT") + if self.send_type == "GET": + # tensor_id: torch.Tensor + self.send_store: dict[str, torch.Tensor] = {} + else: + # PUT or PUT_ASYNC + # tensor_id: torch.Tensor + self.send_queue: deque[list[Any]] = deque() + self.send_request_id_to_tensor_ids: dict[str, set[str]] = {} + if self.send_type == "PUT_ASYNC": + self._send_thread = threading.Thread(target=self._send_async, + daemon=True) + self._send_thread.start() + + # tensor_id: torch.Tensor/(addr, dtype, shape) + self.recv_store: dict[str, Any] = {} + self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {} + self.socks: dict[str, Any] = {} # remote_address: client socket + self.comms: dict[str, Any] = {} # remote_address: (ncclComm_t, rank) + + self.buffer_size = 0 + self.buffer_size_threshold = float(self.config.kv_buffer_size) + + self.nccl_num_channels = self.config.get_from_extra_config( + "nccl_num_channels", "8") + + self._listener_thread = threading.Thread( + target=self._listen_for_requests, daemon=True) + self._listener_thread.start() + + self._ping_thread = None + if port_offset == 0 and self.proxy_address != "": + self._ping_thread = threading.Thread(target=self._ping, + daemon=True) + self._ping_thread.start() + + logger.info( + "💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, " + "zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_" + "threshold:%.2f, nccl_num_channels:%s", self.rank, self.local_rank, + self.http_address, self.zmq_address, self.proxy_address, + self.send_type, self.buffer_size_threshold, self.nccl_num_channels) + + def _create_connect(self, remote_address: typing.Optional[str] = None): + assert remote_address is not None + if remote_address not in self.socks: + sock = self.context.socket(zmq.DEALER) + sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) + sock.connect(f"tcp://{remote_address}") + self.socks[remote_address] = sock + if remote_address in self.comms: + logger.info("👋comm exists, remote_address:%s, comms:%s", + remote_address, self.comms) + return sock, self.comms[remote_address] + + unique_id = self.nccl.ncclGetUniqueId() + data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)} + sock.send(msgpack.dumps(data)) + + with torch.cuda.device(self.device): + rank = 0 + with set_p2p_nccl_context(self.nccl_num_channels): + comm: ncclComm_t = self.nccl.ncclCommInitRank( + 2, unique_id, rank) + self.comms[remote_address] = (comm, rank) + logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s", + self.zmq_address, remote_address, rank) + + return self.socks[remote_address], self.comms[remote_address] + + def send_tensor( + self, + tensor_id: str, + tensor: torch.Tensor, + remote_address: typing.Optional[str] = None, + ) -> bool: + if remote_address is None: + with self.recv_store_cv: + self.recv_store[tensor_id] = tensor + self.recv_store_cv.notify() + return True + else: + if self.send_type == "PUT": + return self._send_sync(tensor_id, tensor, remote_address) + elif self.send_type == "PUT_ASYNC": + with self.send_queue_cv: + self.send_queue.append([tensor_id, remote_address, tensor]) + self.send_queue_cv.notify() + else: # GET + with self.send_store_cv: + tensor_size = tensor.element_size() * tensor.numel() + while (self.buffer_size + tensor_size + > self.buffer_size_threshold): + oldest_tenser_id = next(iter(self.send_store)) + oldest_tenser = self.send_store.pop(oldest_tenser_id) + oldest_tenser_size = oldest_tenser.element_size( + ) * oldest_tenser.numel() + self.buffer_size -= oldest_tenser_size + logger.info( + "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d," + " buffer_size:%d, oldest_tenser_size:%d, rank:%d", + remote_address, tensor_id, tensor_size, + self.buffer_size, oldest_tenser_size, self.rank) + + self.send_store[tensor_id] = tensor + self.buffer_size += tensor_size + logger.debug( + "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, " + "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", + remote_address, tensor_id, tensor_size, tensor.shape, + self.rank, self.buffer_size, + self.buffer_size / self.buffer_size_threshold * 100) + + return True + + def recv_tensor( + self, + tensor_id: str, + remote_address: typing.Optional[str] = None, + ) -> torch.Tensor: + if self.send_type == "PUT" or self.send_type == "PUT_ASYNC": + start_time = time.time() + with self.recv_store_cv: + while tensor_id not in self.recv_store: + self.recv_store_cv.wait() + tensor = self.recv_store[tensor_id] + + if tensor is not None: + if isinstance(tensor, tuple): + addr, dtype, shape = tensor + tensor = self.pool.load_tensor(addr, dtype, shape, + self.device) + else: + self.buffer_size -= (tensor.element_size() * + tensor.numel()) + else: + duration = time.time() - start_time + logger.warning( + "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, " + "rank:%d", remote_address, tensor_id, duration * 1000, + self.rank) + return tensor + + # GET + if remote_address is None: + return None + + if remote_address not in self.socks: + self._create_connect(remote_address) + + sock = self.socks[remote_address] + comm, rank = self.comms[remote_address] + + data = {"cmd": "GET", "tensor_id": tensor_id} + sock.send(msgpack.dumps(data)) + + message = sock.recv() + data = msgpack.loads(message) + if data["ret"] != 0: + logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d", + remote_address, tensor_id, data["ret"]) + return None + + tensor = torch.empty(data["shape"], + dtype=getattr(torch, data["dtype"]), + device=self.device) + + self._recv(comm, tensor, rank ^ 1, self.recv_stream) + + return tensor + + def _listen_for_requests(self): + while True: + socks = dict(self.poller.poll()) + if self.router_socket in socks: + remote_address, message = self.router_socket.recv_multipart() + data = msgpack.loads(message) + if data["cmd"] == "NEW": + unique_id = self.nccl.unique_id_from_bytes( + bytes(data["unique_id"])) + with torch.cuda.device(self.device): + rank = 1 + with set_p2p_nccl_context(self.nccl_num_channels): + comm: ncclComm_t = self.nccl.ncclCommInitRank( + 2, unique_id, rank) + self.comms[remote_address.decode()] = (comm, rank) + logger.info( + "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s", + self.zmq_address, remote_address.decode(), rank) + elif data["cmd"] == "PUT": + tensor_id = data["tensor_id"] + try: + with torch.cuda.stream(self.recv_stream): + tensor = torch.empty(data["shape"], + dtype=getattr( + torch, data["dtype"]), + device=self.device) + self.router_socket.send_multipart( + [remote_address, b"0"]) + comm, rank = self.comms[remote_address.decode()] + self._recv(comm, tensor, rank ^ 1, self.recv_stream) + tensor_size = tensor.element_size() * tensor.numel() + if (self.buffer_size + tensor_size + > self.buffer_size_threshold): + # Store Tensor in memory pool + addr = self.pool.store_tensor(tensor) + tensor = (addr, tensor.dtype, tensor.shape) + else: + self.buffer_size += tensor_size + + except torch.cuda.OutOfMemoryError: + self.router_socket.send_multipart( + [remote_address, b"1"]) + tensor = None + logger.warning( + "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, " + "data:%s", self.zmq_address, + remote_address.decode(), data) + + with self.recv_store_cv: + self.recv_store[tensor_id] = tensor + self._have_received_tensor_id(tensor_id) + self.recv_store_cv.notify() + + elif data["cmd"] == "GET": + tensor_id = data["tensor_id"] + with self.send_store_cv: + tensor = self.send_store.pop(tensor_id, None) + if tensor is not None: + data = { + "ret": 0, + "shape": tensor.shape, + "dtype": + str(tensor.dtype).replace("torch.", "") + } + # LRU + self.send_store[tensor_id] = tensor + self._have_sent_tensor_id(tensor_id) + else: + data = {"ret": 1} + + self.router_socket.send_multipart( + [remote_address, msgpack.dumps(data)]) + + if data["ret"] == 0: + comm, rank = self.comms[remote_address.decode()] + self._send(comm, tensor.to(self.device), rank ^ 1, + self.send_stream) + else: + logger.warning( + "🚧Unexpected, Received message from %s, data:%s", + remote_address, data) + + def _have_sent_tensor_id(self, tensor_id: str): + request_id = tensor_id.split('#')[0] + if request_id not in self.send_request_id_to_tensor_ids: + self.send_request_id_to_tensor_ids[request_id] = set() + self.send_request_id_to_tensor_ids[request_id].add(tensor_id) + + def _have_received_tensor_id(self, tensor_id: str): + request_id = tensor_id.split('#')[0] + if request_id not in self.recv_request_id_to_tensor_ids: + self.recv_request_id_to_tensor_ids[request_id] = set() + self.recv_request_id_to_tensor_ids[request_id].add(tensor_id) + + def _send_async(self): + while True: + with self.send_queue_cv: + while not self.send_queue: + self.send_queue_cv.wait() + tensor_id, remote_address, tensor = self.send_queue.popleft() + if not self.send_queue: + self.send_queue_cv.notify() + self._send_sync(tensor_id, tensor, remote_address) + + def wait_for_sent(self): + if self.send_type == "PUT_ASYNC": + start_time = time.time() + with self.send_queue_cv: + while self.send_queue: + self.send_queue_cv.wait() + duration = time.time() - start_time + logger.debug( + "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue" + " to be empty, rank:%d", duration * 1000, self.rank) + + def _send_sync( + self, + tensor_id: str, + tensor: torch.Tensor, + remote_address: typing.Optional[str] = None, + ) -> bool: + if remote_address is None: + return False + if remote_address not in self.socks: + self._create_connect(remote_address) + + sock = self.socks[remote_address] + comm, rank = self.comms[remote_address] + data = { + "cmd": "PUT", + "tensor_id": tensor_id, + "shape": tensor.shape, + "dtype": str(tensor.dtype).replace("torch.", "") + } + sock.send(msgpack.dumps(data)) + + response = sock.recv() + if response != b"0": + logger.error( + "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, " + "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s", + self.zmq_address, remote_address, rank, data, tensor.shape, + tensor.element_size() * tensor.numel() / 1024**3, + response.decode()) + return False + + self._send(comm, tensor.to(self.device), rank ^ 1, self.send_stream) + + if self.send_type == "PUT_ASYNC": + self._have_sent_tensor_id(tensor_id) + + return True + + def get_finished( + self, finished_req_ids: set[str], forward_context: "ForwardContext" + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens. + + Returns: + ids of requests that have finished asynchronous transfer, + tuple of (sending/saving ids, recving/loading ids). + The finished saves/sends req ids must belong to a set provided in a + call to this method (this call or a prior one). + """ + + # Clear the buffer upon request completion. + for request_id in finished_req_ids: + for layer_name in forward_context.no_compile_layers: + tensor_id = request_id + "#" + layer_name + if tensor_id in self.recv_store: + with self.recv_store_cv: + tensor = self.recv_store.pop(tensor_id, None) + self.send_request_id_to_tensor_ids.pop( + request_id, None) + self.recv_request_id_to_tensor_ids.pop( + request_id, None) + addr = 0 + if isinstance(tensor, tuple): + addr, _, _ = tensor + self.pool.free(addr) + + # TODO:Retrieve requests that have already sent the KV cache. + finished_sending: set[str] = set() + + # TODO:Retrieve requests that have already received the KV cache. + finished_recving: set[str] = set() + + return finished_sending or None, finished_recving or None + + def _ping(self): + sock = self.context.socket(zmq.DEALER) + sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) + logger.debug("ping start, zmq_address:%s", self.zmq_address) + sock.connect(f"tcp://{self.proxy_address}") + data = { + "type": "P" if self.config.is_kv_producer else "D", + "http_address": self.http_address, + "zmq_address": self.zmq_address + } + while True: + sock.send(msgpack.dumps(data)) + time.sleep(3) + + def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None): + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + + with torch.cuda.stream(stream): + self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), dst, + comm, cudaStream_t(stream.cuda_stream)) + stream.synchronize() + + def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None): + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = current_stream() + + with torch.cuda.stream(stream): + self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + comm, cudaStream_t(stream.cuda_stream)) + stream.synchronize() + + def close(self) -> None: + self._listener_thread.join() + if self.send_type == "PUT_ASYNC": + self._send_thread.join() + if self._ping_thread is not None: + self._ping_thread.join() diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py new file mode 100644 index 0000000..02e3bc6 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py @@ -0,0 +1,265 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import atexit +import ctypes +import math +from dataclasses import dataclass + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@dataclass +class MemoryBlock: + size: int + addr: int + + +"""A memory pool for managing pinned host memory allocations for tensors. + +This class implements a buddy allocation system to efficiently manage pinned +host memory for tensor storage. It supports allocation, deallocation, and +tensor storage/retrieval operations. + +Key Features: +- Uses power-of-two block sizes for efficient buddy allocation +- Supports splitting and merging of memory blocks +- Provides methods to store CUDA tensors in pinned host memory +- Allows loading tensors from pinned memory back to device +- Automatically cleans up memory on destruction + +Attributes: + max_block_size (int): Maximum block size (rounded to nearest power of two) + min_block_size (int): Minimum block size (rounded to nearest power of two) + free_lists (dict): Dictionary of free memory blocks by size + allocated_blocks (dict): Dictionary of currently allocated blocks + base_tensor (torch.Tensor): Base pinned memory tensor + base_address (int): Base memory address of the pinned memory region + +Example: + >>> pool = TensorMemoryPool(max_block_size=1024*1024) + >>> tensor = torch.randn(100, device='cuda') + >>> addr = pool.store_tensor(tensor) + >>> loaded_tensor = pool.load_tensor(addr, tensor.dtype, + ... tensor.shape, 'cuda') + >>> pool.free(addr) +""" + + +class TensorMemoryPool: + """Initializes the memory pool with given size constraints. + + Args: + max_block_size (int): Maximum size of memory blocks to manage + min_block_size (int, optional): Minimum size of memory blocks + to manage. Defaults to 512. + + Raises: + ValueError: If block sizes are invalid or max_block_size is less + than min_block_size + """ + + def __init__(self, max_block_size: int, min_block_size: int = 512): + if max_block_size <= 0 or min_block_size <= 0: + raise ValueError("Block sizes must be positive") + if max_block_size < min_block_size: + raise ValueError( + "Max block size must be greater than min block size") + + self.max_block_size = self._round_to_power_of_two(max_block_size) + self.min_block_size = self._round_to_power_of_two(min_block_size) + + self.free_lists: dict[int, dict[int, MemoryBlock]] = {} + self.allocated_blocks: dict[int, MemoryBlock] = {} + + self._initialize_free_lists() + self._allocate_pinned_memory() + + atexit.register(self.cleanup) + + def _round_to_power_of_two(self, size: int) -> int: + return 1 << (size - 1).bit_length() + + def _initialize_free_lists(self): + size = self.max_block_size + while size >= self.min_block_size: + self.free_lists[size] = {} + size //= 2 + + def _allocate_pinned_memory(self): + self.base_tensor = torch.empty(self.max_block_size // 4, + dtype=torch.float32, + pin_memory=True) + self.base_address = self.base_tensor.data_ptr() + initial_block = MemoryBlock(size=self.max_block_size, + addr=self.base_address) + self.free_lists[self.max_block_size][ + initial_block.addr] = initial_block + logger.debug("TensorMemoryPool, base_address:", self.base_address, + self.base_address % self.max_block_size) + + def allocate(self, size: int) -> int: + """Allocates a memory block of at least the requested size. + + Args: + size (int): Minimum size of memory to allocate + + Returns: + int: Address of the allocated memory block + + Raises: + ValueError: If size is invalid or insufficient memory is available + """ + if size <= 0: + raise ValueError("Allocation size must be positive") + + required_size = self._round_to_power_of_two( + max(size, self.min_block_size)) + if required_size > self.max_block_size: + raise ValueError("Requested size exceeds maximum block size") + + current_size = required_size + while current_size <= self.max_block_size: + if self.free_lists[current_size]: + _, block = self.free_lists[current_size].popitem() + self._split_block(block, required_size) + self.allocated_blocks[block.addr] = block + return block.addr + current_size *= 2 + + raise ValueError("Insufficient memory") + + def _split_block(self, block: MemoryBlock, required_size: int): + while (block.size > required_size + and block.size // 2 >= self.min_block_size): + buddy_size = block.size // 2 + buddy_addr = block.addr + buddy_size + + buddy = MemoryBlock(size=buddy_size, addr=buddy_addr) + block.size = buddy_size + + self.free_lists[buddy_size][buddy.addr] = buddy + + def free(self, addr: int): + """Frees an allocated memory block. + + Args: + addr (int): Address of the block to free + + Raises: + ValueError: If address is invalid or not allocated + """ + if addr not in self.allocated_blocks: + raise ValueError("Invalid address to free") + + block = self.allocated_blocks.pop(addr) + self._merge_buddies(block) + + def _merge_buddies(self, block: MemoryBlock): + MAX_MERGE_DEPTH = 30 + depth = 0 + + while depth < MAX_MERGE_DEPTH: + buddy_offset = block.size if (block.addr - self.base_address) % ( + 2 * block.size) == 0 else -block.size + buddy_addr = block.addr + buddy_offset + buddy = self.free_lists[block.size].get(buddy_addr) + if buddy: + del self.free_lists[buddy.size][buddy.addr] + merged_addr = min(block.addr, buddy.addr) + merged_size = block.size * 2 + block = MemoryBlock(size=merged_size, addr=merged_addr) + depth += 1 + else: + break + self.free_lists[block.size][block.addr] = block + + def store_tensor(self, tensor: torch.Tensor) -> int: + """Stores a CUDA tensor in pinned host memory. + + Args: + tensor (torch.Tensor): CUDA tensor to store + + Returns: + int: Address where the tensor is stored + + Raises: + ValueError: If tensor is not on CUDA or allocation fails + """ + if not tensor.is_cuda: + raise ValueError("Only CUDA tensors can be stored") + + size = tensor.element_size() * tensor.numel() + addr = self.allocate(size) + block = self.allocated_blocks[addr] + + if block.size < size: + self.free(addr) + raise ValueError( + f"Allocated block size {block.size} is smaller than " + f"required size {size}") + + try: + buffer = (ctypes.c_byte * block.size).from_address(block.addr) + cpu_tensor = torch.frombuffer(buffer, + dtype=tensor.dtype, + count=tensor.numel()).reshape( + tensor.shape) + except ValueError as err: + self.free(addr) + raise ValueError(f"Failed to create tensor view: {err}") from err + + cpu_tensor.copy_(tensor) + + return addr + + def load_tensor(self, addr: int, dtype: torch.dtype, + shape: tuple[int, ...], device) -> torch.Tensor: + """Loads a tensor from pinned host memory to the specified device. + + Args: + addr (int): Address where tensor is stored + dtype (torch.dtype): Data type of the tensor + shape (tuple[int, ...]): Shape of the tensor + device: Target device for the loaded tensor + + Returns: + torch.Tensor: The loaded tensor on the specified device + + Raises: + ValueError: If address is invalid or sizes don't match + """ + if addr not in self.allocated_blocks: + raise ValueError("Invalid address to load") + + block = self.allocated_blocks[addr] + num_elements = math.prod(shape) + dtype_size = torch.tensor([], dtype=dtype).element_size() + required_size = num_elements * dtype_size + + if required_size > block.size: + raise ValueError("Requested tensor size exceeds block size") + + buffer = (ctypes.c_byte * block.size).from_address(block.addr) + cpu_tensor = torch.frombuffer(buffer, dtype=dtype, + count=num_elements).reshape(shape) + + cuda_tensor = torch.empty(shape, dtype=dtype, device=device) + + cuda_tensor.copy_(cpu_tensor) + + return cuda_tensor + + def cleanup(self): + """Cleans up all memory resources and resets the pool state.""" + self.free_lists.clear() + self.allocated_blocks.clear() + if hasattr(self, 'base_tensor'): + del self.base_tensor + + def __del__(self): + self.cleanup() diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py new file mode 100644 index 0000000..3c574d0 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -0,0 +1,389 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import hashlib +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import safetensors +import torch + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import MLACommonMetadata +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +@dataclass +class ReqMeta: + # Request tokens + token_ids: torch.Tensor + # Slot mappings, should have the same length as token_ids + slot_mapping: torch.Tensor + # Is store or load + is_store: bool + + @staticmethod + def make_meta(token_ids: list[int], block_ids: list[int], block_size: int, + is_store: bool) -> "ReqMeta": + valid_num_tokens = align_to_block_size(len(token_ids), block_size) + token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens] + block_ids_tensor = torch.tensor(block_ids) + num_blocks = block_ids_tensor.shape[0] + block_offsets = torch.arange(0, block_size) + slot_mapping = block_offsets.reshape((1, block_size)) + \ + block_ids_tensor.reshape((num_blocks, 1)) * block_size + slot_mapping = slot_mapping.flatten()[:valid_num_tokens] + return ReqMeta( + token_ids=token_ids_tensor, + slot_mapping=slot_mapping, + is_store=is_store, + ) + + +@dataclass +class SharedStorageConnectorMetadata(KVConnectorMetadata): + requests: list[ReqMeta] + + def __init__(self): + self.requests = [] + + def add_request( + self, + token_ids: list[int], + block_ids: list[int], + block_size: int, + is_store: bool, + ) -> None: + self.requests.append( + ReqMeta.make_meta(token_ids, block_ids, block_size, is_store)) + + +class SharedStorageConnector(KVConnectorBase_V1): + # NOTE: This is Simple debug implementation of the KV connector. + # It save / load the KV cache to / from the disk. + # It does extra work which will overwrite the existing prefix-cache in GPU + # - to remove the overhead, need to add some "mask" in the ReqMeta class + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self._block_size = vllm_config.cache_config.block_size + self._requests_need_load: dict[str, Request] = {} + transfer_config = vllm_config.kv_transfer_config + self._storage_path = transfer_config.get_from_extra_config( + "shared_storage_path", "/tmp") + logger.info(vllm_config.kv_transfer_config) + logger.info("Shared storage path is %s", self._storage_path) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + """Start loading the KV cache from the connector buffer to vLLM's + paged KV buffer. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + """ + attn_metadata = forward_context.attn_metadata + + def inject_kv_into_layer( + dst_kv_cache_layer: torch.Tensor, + src_kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> None: + """Inject the KV cache into the layer. + + Args: + dst_kv_cache_layer (torch.Tensor): the destination KV cache + layer. In shape [2, num_pages, page_size, xxx] if not + using MLA, [num_pages, page_size, xxx] otherwise. + src_kv_cache (torch.Tensor): the source KV cache. In shape + [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] + otherwise. + slot_mapping (torch.Tensor): the slot mapping. In shape + [num_tokens]. + """ + dst_kv_cache_layer_shape = dst_kv_cache_layer.shape + if isinstance(attn_metadata, MLACommonMetadata): + num_pages = dst_kv_cache_layer_shape[0] + page_size = dst_kv_cache_layer_shape[1] + dst_kv_cache_layer = dst_kv_cache_layer.reshape( + num_pages * page_size, -1) + dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache + dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) + else: + num_pages = dst_kv_cache_layer_shape[1] + page_size = dst_kv_cache_layer_shape[2] + dst_kv_cache_layer = dst_kv_cache_layer.reshape( + 2, num_pages * page_size, -1) + dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache + dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) + + # Get the metadata + metadata: KVConnectorMetadata = self._get_connector_metadata() + assert isinstance(metadata, SharedStorageConnectorMetadata) + + if metadata is None: + logger.warning( + "In connector.start_load_kv, but the connector metadata is None" + ) + return + + attn_metadata = forward_context.attn_metadata + if attn_metadata is None: + logger.warning( + "In connector.start_load_kv, but the attn_metadata is None") + return + + # Load the KV for each request each layer + for request in metadata.requests: + if request.is_store: + continue + logger.info("Inject KV cache of %d tokens to the paged memory", + len(request.slot_mapping)) + for layer_name in forward_context.no_compile_layers: + attn_layer = forward_context.no_compile_layers[layer_name] + kv_cache_layer = attn_layer.kv_cache[\ + forward_context.virtual_engine] + + filename = self._generate_filename_debug( + layer_name, request.token_ids) + kv_cache = safetensors.torch.load_file( + filename)["kv_cache"].cuda() + inject_kv_into_layer(kv_cache_layer, kv_cache, + request.slot_mapping) + + def wait_for_layer_load(self, layer_name: str) -> None: + """Blocking until the KV for a specific layer is loaded into vLLM's + paged buffer. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + return + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """Start saving the KV cache of the layer from vLLM's paged buffer + to the connector. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + + def extract_kv_from_layer( + layer: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> torch.Tensor: + """Extract the KV cache from the layer. + + Assume the shape of the layer is (2, num_pages, page_size, xxx) + if MLA is not used, and (num_pages, page_size, xxx) otherwise. + """ + if isinstance(attn_metadata, MLACommonMetadata): + num_pages, page_size = layer.shape[0], layer.shape[1] + return layer.reshape(num_pages * page_size, -1)[slot_mapping, + ...] + num_pages, page_size = layer.shape[1], layer.shape[2] + return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, + ...] + + connector_metadata = self._get_connector_metadata() + assert isinstance(connector_metadata, SharedStorageConnectorMetadata) + for request in connector_metadata.requests: + if request.is_store: + filename = self._generate_filename_debug( + layer_name, request.token_ids) + kv_cache = extract_kv_from_layer(kv_layer, + request.slot_mapping) + tensors = {"kv_cache": kv_cache.detach().cpu()} + safetensors.torch.save_file(tensors, filename) + + def wait_for_save(self): + return + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + # NOTE: in this debug implementation, we assume that the prompt is + # cached_prompt + newly_generated_single_token + # Therefore, we use prompt_token_ids[:-1] to determine the folder name + + # NOTE: in current v1 scheduler, the num_computed_tokens is aligned + # with the block granularity. And it expects the returned blocks and + # num_computed_tokens to also be aligned with the block granularity. + if not self._found_match_for_request(request): + return 0, False + + logger.info("External Cache Hit!") + + # Now, first num_tokens_to_check tokens are hit, we need to prepare + # the metadata for the worker connector to correctly load the KV + num_tokens_to_check = align_to_block_size( + len(request.prompt_token_ids) - 1, self._block_size) + + return num_tokens_to_check - num_computed_tokens, False + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + """ + Update KVConnector state after block allocation. + + If blocks were allocated, add to _requests_need_load, + such that we load the KVs in the next forward pass. + """ + if num_external_tokens > 0: + self._requests_need_load[request.request_id] = request + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + """Build the connector metadata for this step. + + This function should NOT modify any fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + meta = SharedStorageConnectorMetadata() + + total_need_load = 0 + for new_req in scheduler_output.scheduled_new_reqs: + if new_req.req_id in self._requests_need_load: + meta.add_request(token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size, + is_store=False) + total_need_load += 1 + else: + # NOTE: here, we set the store and load being exclusive, + # but a single request can have both store and load. + # NOTE(rob): for this debug implementation, we only cache + # the original prompt tokens. + if not self._found_match_for_request(new_req): + meta.add_request(token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size, + is_store=True) + + cached_reqs = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(cached_reqs.req_ids): + num_computed_tokens = cached_reqs.num_computed_tokens[i] + num_new_tokens = scheduler_output.num_scheduled_tokens[req_id] + new_block_ids = cached_reqs.new_block_ids[i] + resumed_from_preemption = cached_reqs.resumed_from_preemption[i] + + # NOTE(rob): here we rely on the resumed requests being + # the first N requests in the list scheduled_cache_reqs. + if not resumed_from_preemption: + break + if req_id in self._requests_need_load: + # NOTE(rob): cached_req_data does not have the full + # list of token ids (only new tokens). So we look it + # up in the actual request object. + request = self._requests_need_load[req_id] + total_tokens = num_computed_tokens + num_new_tokens + token_ids = request.all_token_ids[:total_tokens] + + # NOTE(rob): For resumed req, new_block_ids is all + # of the block_ids for the request. + block_ids = new_block_ids[0] + + meta.add_request(token_ids=token_ids, + block_ids=block_ids, + block_size=self._block_size, + is_store=False) + total_need_load += 1 + + assert total_need_load == len(self._requests_need_load) + self._requests_need_load.clear() + return meta + + # ============================== + # Helper functions + # ============================== + + def _found_match_for_request( + self, + request: "Request", + ) -> bool: + """Check if the cache is hit for the request. + """ + num_tokens_to_check = align_to_block_size( + len(request.prompt_token_ids) - 1, self._block_size) + foldername = self._generate_foldername_debug(torch.tensor( + request.prompt_token_ids)[:num_tokens_to_check], + create_folder=False) + return os.path.exists(foldername) + + def _generate_foldername_debug( + self, + input_ids: torch.Tensor, + create_folder=False, + ) -> str: + """Generate a folder name based on the hash of the bytes of the input + ids. + """ + input_ids_bytes = input_ids.numpy().tobytes() + input_ids_hash = hashlib.md5(input_ids_bytes, + usedforsecurity=False).hexdigest() + foldername = os.path.join(self._storage_path, input_ids_hash) + if create_folder: + os.makedirs(foldername, exist_ok=True) + return foldername + + def _generate_filename_debug( + self, + layer_name: str, + input_ids: torch.Tensor, + ) -> str: + """Generate a file name based on the layer name and the hash + of the bytes of the input ids. + """ + foldername = self._generate_foldername_debug(input_ids, + create_folder=True) + return os.path.join(foldername, f"{layer_name}.safetensors") + + +def align_to_block_size(num_tokens: int, block_size) -> int: + """Align the number of tokens to the block size. + """ + return (num_tokens - 1) // block_size * block_size diff --git a/vllm/distributed/kv_transfer/kv_connector_agent.py b/vllm/distributed/kv_transfer/kv_connector_agent.py new file mode 100644 index 0000000..8633fda --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector_agent.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""A centralized entrypoint to perform distributed KV cache transfer. + +This implementation is a shim wrapper on two APIs exposed by `kv_connector`: +1. `send_kv_caches_and_hidden_states` +2. `recv_kv_caches_and_hidden_states +""" +from typing import TYPE_CHECKING, Union + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + from vllm.config import VllmConfig + +import torch + +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +logger = init_logger(__name__) + + +class KVTransferAgent: + """ + A class designated for distributed KV transfer + + Target use cases: + 1. Disaggregated prefill + 2. Remote KV cache storage + """ + + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ): + + self.config = config + + if config.kv_transfer_config is None: + raise ValueError("KVTransferConfig is not set in the VllmConfig," + " cannot initialize KVConnector.") + + assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\ + "TransferAgent should only be used when kv_connector is set." + + self.connector = KVConnectorFactory.create_connector_v0( + rank, local_rank, config) + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: list[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + self.connector.send_kv_caches_and_hidden_states( + model_executable, model_input, kv_caches, + hidden_or_intermediate_states) + + def close(self) -> None: + self.connector.close() + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: list[torch.Tensor] + ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + return self.connector.recv_kv_caches_and_hidden_states( + model_executable, model_input, kv_caches) diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py new file mode 100644 index 0000000..eef1426 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This file contains a new class `KVLookupBufferBase` that allows developers to +think of KV cache operations as inserting new KV cache entries (`insert`) +into the lookup buffer and querying existing KV caches (`drop_select`) +from the lookup buffer. + +This file also contains a new class `KVStoreBufferBase` that allows developers +to manage the KVCache buffer as a simple key-value storage buffer with basic +put/get operations. + +These classes above are abstracted behind class `KVCacheBufferBase`. +""" + +from abc import ABC, abstractmethod +from typing import Optional + +import torch + + +class KVCacheBufferBase(ABC): + """ + Abstract base class for a KVCache buffer. + """ + + @abstractmethod + def close(self) -> None: + """Close the buffer and release resources. + + This method is responsible for cleaning up resources related to the + KVCache buffer when it is no longer needed. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + +class KVLookupBufferBase(KVCacheBufferBase): + """ + Abstract base class for a KVCache lookup buffer. + + This class provides an abstraction for a key-value (KV) cache lookup buffer. + + The key of the lookup buffer: + - input_tokens: token IDs of the request + - roi: a binary mask on top of input_tokens. + - Purpose of roi: Since KV cache may only be available for a subset of + tokens in the input (for example, when vLLM is connected to an external + KV cache service), roi specifies the subset of tokens that the KV cache + is associated with. + - NOTE: roi can be further extended to describe which part of KV the + current process is holding (each process may only hold a part of KV + due to TP and PP). This is not implemented for now. + + The value of the lookup buffer: + - key: the key tensor in the KV cache + - value: the value tensor in the KV cache + - hidden: the final hidden state generated by model forwarding. This allows + vLLM to bypass further model forwarding by transmitting the hidden state. + """ + + @abstractmethod + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + """Insert into the lookup buffer. + + The functionality is similar to the following python statement + ``` + buffer[input_tokens, roi] = [key, value, hidden] + ``` + + FIXME: in the future, we should only have two arguments, key and value, + where key is a tensor dict and value is a tensor dict. + + FIXME: we should transmit both sampler outputs and the hidden states. + + Args: + input_tokens (torch.Tensor): token IDs. + roi (torch.Tensor): A binary mask on top of the input tokens + key (torch.Tensor): The key tensor in the KV cache. + value (torch.Tensor): The value tensor in the KV cache. + hidden (torch.Tensor): The final hidden state tensor generated + during model forwarding to bypass model + forwarding. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def drop_select( + self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]: + """Select and *drop* KV cache entries from the lookup buffer. + + The functionality is similar to the following python statements + ``` + ret = buffer.pop(input_tokens, roi) + return ret + ``` + + If `input_tokens` and `roi` is `None`, it means selecting any of the + KV caches in the buffer, return, and remove it from the buffer, useful + when offloading KV cache to KV cache storage service. + + Args: + input_tokens (torch.Tensor): token IDs. + roi (torch.Tensor): A binary mask on top of the input tokens + + Returns: + list[Optional[torch.Tensor]]: A list of tensors. Can be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + +class KVStoreBufferBase(KVCacheBufferBase): + """ + Abstract base class for a KVCache storage buffer with key-value semantics. + This class provides a simple key-value storage buffer abstract with basic + put/get operations, which enables flexible KVCache transfer granular + control. + + The functionality is similar to a distributed key-value store, where: + - Key: A unique string identifier for the cached entry + - Value: + - Tensor to be stored and retrieved + - None (indicating deletion or empty value) + """ + + @abstractmethod + def put( + self, + key: str, + value: Optional[torch.Tensor], + ) -> None: + """Store a key-value pair in the buffer. + + Args: + key (str): Unique identifier for a tensor, this tensor could be the + key cache tensor, value cache tensor, or hidden state tensor + generated during model forwarding. + + value (Optional[torch.Tensor]): Tensor to be stored. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def get( + self, + key: str, + ) -> Optional[torch.Tensor]: + """Retrieve a value from the buffer by key. + + Args: + key (str): Unique identifier for a tensor, this tensor could be the + key cache tensor, value cache tensor, or hidden state tensor + generated during model forwarding. + + Returns: + Optional[torch.Tensor]: Stored tensor if exists, None otherwise. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py new file mode 100644 index 0000000..4381aad --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This file contains a new class `MooncakeStore` that allows developers to +think of KV cache transfer operations as putting new KV cache entries +into a remote KVStore-based lookup buffer and getting existing KV caches +from this remote lookup buffer. +""" +import json +import os +from dataclasses import dataclass +from typing import Optional + +import torch +from safetensors.torch import load as safetensors_load +from safetensors.torch import save as safetensors_save + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( + KVStoreBufferBase) +from vllm.logger import init_logger + +DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB +DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB + +logger = init_logger(__name__) + + +@dataclass +class MooncakeStoreConfig: + local_hostname: str + metadata_server: str + global_segment_size: int + local_buffer_size: int + protocol: str + device_name: str + master_server_address: str + + @staticmethod + def from_file(file_path: str) -> 'MooncakeStoreConfig': + """Load the config from a JSON file.""" + with open(file_path) as fin: + config = json.load(fin) + return MooncakeStoreConfig( + local_hostname=config.get("local_hostname"), + metadata_server=config.get("metadata_server"), + global_segment_size=config.get("global_segment_size", + DEFAULT_GLOBAL_SEGMENT_SIZE), + local_buffer_size=config.get("local_buffer_size", + DEFAULT_LOCAL_BUFFER_SIZE), + protocol=config.get("protocol", "tcp"), + device_name=config.get("device_name", ""), + master_server_address=config.get("master_server_address"), + ) + + @staticmethod + def load_from_env() -> 'MooncakeStoreConfig': + """Load config from a file specified in the environment variable.""" + config_file_path = os.getenv('MOONCAKE_CONFIG_PATH') + if config_file_path is None: + raise ValueError( + "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + return MooncakeStoreConfig.from_file(config_file_path) + + +class MooncakeStore(KVStoreBufferBase): + + def __init__( + self, + config: VllmConfig, + ): + + try: + from mooncake.store import MooncakeDistributedStore + except ImportError as e: + raise ImportError( + "Please install mooncake by following the instructions at " + "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 + "to run vLLM with MooncakeConnector.") from e + + try: + self.store = MooncakeDistributedStore() + self.config = MooncakeStoreConfig.load_from_env() + logger.info("Mooncake Configuration loaded successfully.") + + self.store.setup(self.config.local_hostname, + self.config.metadata_server, + self.config.global_segment_size, + self.config.local_buffer_size, + self.config.protocol, self.config.device_name, + self.config.master_server_address) + + except ValueError as e: + logger.error("Configuration loading failed: %s", e) + raise + except Exception as exc: + logger.error( + "An error occurred while loading the configuration: %s", exc) + raise + + def close(self): + # MooncakeDistributedStore will automatically call the destructor, so + # it is unnecessary to close it manually. + pass + + def put( + self, + key: str, + value: Optional[torch.Tensor], + ) -> None: + # A message queue needs to be introduced before making it asynchronous. + if value is not None: + self._put_impl(key, value) + + def get( + self, + key: str, + ) -> Optional[torch.Tensor]: + # A message queue needs to be introduced before making it asynchronous. + value = self._get_impl(key) + return value + + def _put_impl( + self, + key: str, + value: torch.Tensor, + ) -> None: + """Put KVCache to Mooncake Store""" + device_id = value.device.index if value.device.type == 'cuda' else -1 + device_tensor = torch.tensor(device_id, dtype=torch.int32) + value_bytes = safetensors_save({ + "tensor": value, + "device_id": device_tensor + }) + try: + self.store.put(key, value_bytes) + except TypeError as err: + logger.error("Failed to put value into Mooncake Store: %s", err) + raise TypeError("Mooncake Store Put Type Error.") from err + + def _get_impl( + self, + key: str, + ) -> Optional[torch.Tensor]: + """Get KVCache from Mooncake Store""" + try: + data = self.store.get(key) + except TypeError as err: + logger.error("Failed to get value from Mooncake Store: %s", err) + raise TypeError("Mooncake Store Get Type Error.") from err + + if data: + loaded_tensors = safetensors_load(data) + tensor = loaded_tensors["tensor"] + device_id_tensor = loaded_tensors["device_id"] + device_id = int(device_id_tensor.item()) + device = torch.device( + 'cuda', device_id) if device_id >= 0 else torch.device('cpu') + return tensor.to(device) + + return None diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py new file mode 100644 index 0000000..a0ff7c3 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -0,0 +1,237 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" + Implements a distributed key-value (KV) cache transfer mechanism. + + Key Features: + - Distributed KV cache transmission using PyNccl pipes. + - Non-blocking `insert`, blocking `drop_select`. + - Use CPU signal pipe to avoid racing condition + - Handles buffer size constraints and provide backpressure mechanism to + stop the prefill instance when the decode instance is slow. +""" +import threading +from collections import deque +from typing import Optional, Union + +import torch + +from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( + KVLookupBufferBase) +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class SimpleBuffer(KVLookupBufferBase): + + def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, + buffer_size_thresh: float): + """ + signal_pipe: on CPU + + NOTE: on-device recv will block all threads in the process, making the + KV cache producer unable to listen to new request while transmitting + KV cache. Luckily CPU recv only blocks the current thread so we use + CPU recv to listen to new request. + + data_pipe: on device (e.g. GPU) + """ + + self.buffer: deque[list[torch.Tensor]] = deque() + + self.buffer_size = 0 + self.buffer_size_threshold = buffer_size_thresh + self.buffer_cv = threading.Condition() + self.signal_pipe = signal_pipe + self.data_pipe = data_pipe + self.request_handling_thread: Optional[threading.Thread] = None + + self.normal_signal = torch.tensor([0], device="cpu") + self.end_signal = None + + def _matches(self, tokens_roi_sender: list[torch.Tensor], + tokens_roi_recver: list[torch.Tensor]): + + # tokens_roi_sender: tokens and roi of the producer (in the buffer) + # tokens_roi_recver: tokens and roi of the consumer (query) + + tokens_sender = tokens_roi_sender[0] + tokens_recver = tokens_roi_recver[0] + roi_sender = tokens_roi_sender[1] + roi_recver = tokens_roi_recver[1] + + if tokens_recver is None: + # consumer sends an empty request + # semantics: DROP SELECT * LIMIT 1 + # so any of the data in the buffer can be drop-selected + return True + + # Assuming that roi is a binary mask on tokens + tokens_sender = tokens_sender[roi_sender] + tokens_recver = tokens_recver[roi_recver] + + # simple common prefix matching + min_length = min(len(tokens_sender), len(tokens_recver)) + if torch.allclose(tokens_sender[:min_length], + tokens_recver[:min_length]): + return min_length + + return 0 + + def _send_tensor_and_dec_size(self, + tensor: Optional[torch.Tensor]) -> None: + + assert tensor is not None, "Use self.data_pipe.send(None) instead" + self.buffer_size -= tensor.element_size() * tensor.numel() + if tensor.dtype == torch.bool: + tensor = tensor.float() + self.data_pipe.send_tensor(tensor) + + def _get_element_size(self, data: Optional[Union[list, torch.Tensor]]): + + if isinstance(data, torch.Tensor): + return data.element_size() * data.numel() + if not data: + # cannot perform `not data` on a tensor + # so this check needs to go after the check above + return 0 + + raise AssertionError(f"Unknown data type {type(data)}") + + def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor): + + if isinstance(input_tokens, torch.Tensor): + input_tokens = input_tokens.clone() + if isinstance(roi, torch.Tensor): + roi = roi.clone() + if isinstance(key, torch.Tensor): + key = key.clone() + if isinstance(value, torch.Tensor): + value = value.clone() + if isinstance(hidden, torch.Tensor): + hidden = hidden.clone() + + buffer_item = [input_tokens, roi, key, value, hidden] + data_size = sum([self._get_element_size(data) for data in buffer_item]) + + with self.buffer_cv: + if self.buffer_size + data_size > self.buffer_size_threshold: + # log outside the while loop to avoid this message being logged + # repeatedly. + logger.debug("KV transfer buffer is full. Handling...") + while self.buffer_size + data_size > self.buffer_size_threshold: + self.buffer_cv.wait() + + self.buffer_size += data_size + self.buffer.append(buffer_item) + self.buffer_cv.notify() + + def _is_end_signal(self, signal): + return signal is None + + def drop_select_handler(self): + + try: + + while True: + signal = self.signal_pipe.recv_tensor() + if self._is_end_signal(signal): + logger.info("Received end signal!") + break + + input_tokens = self.data_pipe.recv_tensor() + + roi = self.data_pipe.recv_tensor() + assert roi is not None, "Please provide the roi when sending "\ + "drop-select request" + roi = (roi > 0.5) + tokens_roi_recver = [input_tokens, roi] + + def is_buffer_available( + tokens_roi_recver: list[torch.Tensor], ) -> bool: + # perform input tokens and roi matching + # FIXME: this matching is O(n), ideally it should be O(1) + # but this buffer size won't (and shouldn't) be too large so + # the fix is not urgent. + for _ in range(len(self.buffer)): + if self._matches(self.buffer[0], + tokens_roi_recver) > 0: + return True + # rotate the element we just accessed to the end + self.buffer.rotate(-1) + return False + + with self.buffer_cv: + while not is_buffer_available(tokens_roi_recver): + logger.debug( + "KV transfer buffer is not available. Waiting...") + self.buffer_cv.wait() + # need to clone the tensor + # in case the tensor is freed before sending finishes + matched_item = self.buffer.popleft() + for tensor in matched_item: + self._send_tensor_and_dec_size(tensor) + self.buffer_cv.notify() + + except RuntimeError as e: + if 'Connection closed by peer' not in str(e): + raise e + + logger.debug("Closing drop_select_handler") + + def drop_select( + self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]: + + assert self.request_handling_thread is None, \ + "drop_select should be called by the KV cache consumer "\ + "(e.g. the decode vLLM instance)" + + if isinstance(input_tokens, torch.Tensor): + input_tokens = input_tokens.clone() + if isinstance(roi, torch.Tensor): + roi = roi.clone().float() + + self.signal_pipe.send_tensor(self.normal_signal) + self.data_pipe.send_tensor(input_tokens) + self.data_pipe.send_tensor(roi) + + input_tokens = self.data_pipe.recv_tensor() + roi = self.data_pipe.recv_tensor() + if roi is not None: + # convert from float tensor to bool tensor + # as PyNccl does not support sending bool tensor + roi = (roi > 0.5) + key = self.data_pipe.recv_tensor() + value = self.data_pipe.recv_tensor() + hidden = self.data_pipe.recv_tensor() + + return [input_tokens, roi, key, value, hidden] + + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + + self._add_to_buffer(input_tokens, roi, key, value, hidden) + + # when calling the insert, the current process is a sender + # need to launch the request handler and start listening to request. + if self.request_handling_thread is None: + self.request_handling_thread = threading.Thread( + target=self.drop_select_handler) + self.request_handling_thread.start() + + def close(self): + + if hasattr(self, "request_handling_thread" + ) and self.request_handling_thread is not None: + self.request_handling_thread.join() + + else: + # TODO: have a explicit close signal and have a explicit way to + # check if it's requester + self.signal_pipe.send_tensor(self.end_signal) diff --git a/vllm/distributed/kv_transfer/kv_pipe/__init__.py b/vllm/distributed/kv_transfer/kv_pipe/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py new file mode 100644 index 0000000..1423fd0 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/base.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This file defines an interface `KVPipeBase` +that provides an abstraction for sending and receiving tensors, or None, via +distributed communications. + +All classes instantiated from this interface are assumed to be a FIFO pipe. + +If your distributed communication platform already supports key-value lookup, +you can bypass this interface and directly start from `kv_lookup_buffer`. +""" + +from abc import ABC, abstractmethod +from typing import Optional + +import torch + + +class KVPipeBase(ABC): + """ + This class provides an interface for sending and receiving tensors, or + None, by distributed communications. + """ + + @abstractmethod + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """Send a tensor, or None, via the pipe. + + Need to support sending None -- important for error handling. + + TODO: add a `key` argument so that we can use traditional + key-value database as the distributed communication mechanism behind + the pipe. + + Args: + tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def recv_tensor(self) -> Optional[torch.Tensor]: + """Receive a tensor (can be None) from the pipeline. + + Returns: + Optional[torch.Tensor]: The tensor received from the pipeline. Can + be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the pipeline and release resources. + + This method is responsible for closing the communication pipeline + and releasing any resources associated with it. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py new file mode 100644 index 0000000..0b560d1 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py @@ -0,0 +1,290 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +import os +import struct +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import zmq +from safetensors.torch import load as safetensors_load +from safetensors.torch import save as safetensors_save + +from vllm.config import KVTransferConfig +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase +from vllm.logger import init_logger +from vllm.utils import join_host_port, make_zmq_path, split_host_port + +logger = init_logger(__name__) +NONE_INT = -150886311 + + +@dataclass +class MooncakeTransferEngineConfig: + prefill_url: str + decode_url: str + metadata_backend: Union[str, None] + metadata_server: str + protocol: str + device_name: str + + @staticmethod + def from_file(file_path: str) -> 'MooncakeTransferEngineConfig': + """Load the config from a JSON file.""" + with open(file_path) as fin: + config = json.load(fin) + return MooncakeTransferEngineConfig( + prefill_url=config.get("prefill_url"), + decode_url=config.get("decode_url"), + metadata_backend=config.get("metadata_backend", None), + metadata_server=config.get("metadata_server"), + protocol=config.get("protocol", "tcp"), + device_name=config.get("device_name", ""), + ) + + @staticmethod + def load_from_env() -> 'MooncakeTransferEngineConfig': + """Load config from a file specified in the environment variable.""" + config_file_path = os.getenv('MOONCAKE_CONFIG_PATH') + if config_file_path is None: + raise ValueError( + "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + return MooncakeTransferEngineConfig.from_file(config_file_path) + + +class MooncakeTransferEngine: + """Handles the transfer of data using mooncake_vllm_adaptor and ZeroMQ.""" + + def __init__(self, kv_rank: int, local_rank: int): + try: + from mooncake.engine import TransferEngine + except ImportError as e: + raise ImportError( + "Please install mooncake by following the instructions at " + "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 + "to run vLLM with MooncakeConnector.") from e + + self.engine = TransferEngine() + self.local_rank = local_rank + + try: + self.config = MooncakeTransferEngineConfig.load_from_env() + logger.info("Mooncake Configuration loaded successfully.") + except ValueError as e: + logger.error(e) + raise + except Exception as exc: + logger.error( + "An error occurred while loading the configuration: %s", exc) + raise + prefill_host, base_prefill_port = split_host_port( + self.config.prefill_url) + decode_host, base_decode_port = split_host_port(self.config.decode_url) + + # Avoid ports conflict when running prefill and decode on the same node + if prefill_host == decode_host and \ + base_prefill_port == base_decode_port: + base_decode_port = base_decode_port + 100 + + prefill_port = base_prefill_port + self.local_rank + decode_port = base_decode_port + self.local_rank + self.prefill_url = join_host_port(prefill_host, prefill_port) + self.decode_url = join_host_port(decode_host, decode_port) + + self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url, + self.config.metadata_server, self.config.protocol, + self.config.device_name, self.config.metadata_backend) + + self.remote_url = (self.decode_url + if kv_rank == 0 else self.prefill_url) + + # Initialize ZeroMQ context and sockets + self.context = zmq.Context() # type: ignore[attr-defined] + self.sender_socket = self.context.socket(zmq.constants.PUSH) + self.receiver_socket = self.context.socket(zmq.constants.PULL) + self.sender_ack = self.context.socket(zmq.constants.PULL) + self.receiver_ack = self.context.socket(zmq.constants.PUSH) + + self.buffer_cleaner = ThreadPoolExecutor(max_workers=1) + self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port, + decode_host, base_decode_port) + + def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: int, + d_host: str, d_port: int) -> None: + """Set up ZeroMQ sockets for sending and receiving data.""" + # Offsets < 8 are left for initialization in case tp and pp are enabled + p_rank_offset = p_port + 8 + self.local_rank * 2 + d_rank_offset = d_port + 8 + self.local_rank * 2 + if kv_rank == 0: + self.sender_socket.bind( + make_zmq_path("tcp", p_host, p_rank_offset + 1)) + self.receiver_socket.connect( + make_zmq_path("tcp", d_host, d_rank_offset + 1)) + self.sender_ack.connect( + make_zmq_path("tcp", d_host, d_rank_offset + 2)) + self.receiver_ack.bind( + make_zmq_path("tcp", p_host, p_rank_offset + 2)) + else: + self.receiver_socket.connect( + make_zmq_path("tcp", p_host, p_rank_offset + 1)) + self.sender_socket.bind( + make_zmq_path("tcp", d_host, d_rank_offset + 1)) + self.receiver_ack.bind( + make_zmq_path("tcp", d_host, d_rank_offset + 2)) + self.sender_ack.connect( + make_zmq_path("tcp", p_host, p_rank_offset + 2)) + + def initialize(self, local_hostname: str, metadata_server: str, + protocol: str, device_name: str, + metadata_backend: Union[str, None]) -> None: + """Initialize the mooncake instance.""" + if metadata_backend is None: + self.engine.initialize(local_hostname, metadata_server, protocol, + device_name) + else: + supported_backend = ["etcd", "redis"] + metadata_backend = metadata_backend.lower() + if metadata_backend not in supported_backend: + raise ValueError( + "Mooncake Configuration error. `metadata_backend`" + f" should be one of {supported_backend}.") + + self.engine.initialize_ext(local_hostname, metadata_server, + protocol, device_name, metadata_backend) + + def allocate_managed_buffer(self, length: int) -> int: + """Allocate a managed buffer of the specified length.""" + ret = self.engine.allocate_managed_buffer(length) + if ret <= 0: + logger.error("Allocation Return Error") + raise Exception("Allocation Return Error") + return ret + + def free_managed_buffer(self, buffer: int, length: int) -> int: + """Free a previously allocated managed buffer.""" + return self.engine.free_managed_buffer(buffer, length) + + def transfer_sync(self, buffer: int, peer_buffer_address: int, + length: int) -> int: + """Synchronously transfer data to the specified address.""" + ret = self.engine.transfer_sync_read(self.remote_url, buffer, + peer_buffer_address, length) + if ret < 0: + logger.error("Transfer Return Error") + raise Exception("Transfer Return Error") + return ret + + def write_bytes_to_buffer(self, buffer: int, user_data: bytes, + length: int) -> int: + """Write bytes to the allocated buffer.""" + return self.engine.write_bytes_to_buffer(buffer, user_data, length) + + def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes: + """Read bytes from the allocated buffer.""" + return self.engine.read_bytes_from_buffer(buffer, length) + + def wait_for_ack(self, src_ptr: int, length: int) -> None: + """Asynchronously wait for ACK from the receiver.""" + ack = self.sender_ack.recv() + if ack != b'ACK': + logger.error("Failed to receive ACK from the receiver") + + self.free_managed_buffer(src_ptr, length) + + def send_bytes(self, user_data: bytes) -> None: + """Send bytes to the remote process.""" + length = len(user_data) + src_ptr = self.allocate_managed_buffer(length) + self.write_bytes_to_buffer(src_ptr, user_data, length) + self.sender_socket.send_multipart( + [struct.pack("!Q", src_ptr), + struct.pack("!Q", length)]) + self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length) + + def recv_bytes(self) -> bytes: + """Receive bytes from the remote process.""" + data = self.receiver_socket.recv_multipart() + src_ptr = struct.unpack("!Q", data[0])[0] + length = struct.unpack("!Q", data[1])[0] + dst_ptr = self.allocate_managed_buffer(length) + self.transfer_sync(dst_ptr, src_ptr, length) + ret = self.read_bytes_from_buffer(dst_ptr, length) + + # Buffer cleanup + self.receiver_ack.send(b'ACK') + self.free_managed_buffer(dst_ptr, length) + + return ret + + +class MooncakePipe(KVPipeBase): + """MooncakeTransferEngine based Pipe implementation.""" + + def __init__(self, + local_rank: int, + config: KVTransferConfig, + device: Optional[str] = None): + """Initialize the mooncake pipe and set related parameters.""" + self.config = config + self.local_rank = local_rank + self.kv_rank = self.config.kv_rank + if device is None: + self.device = self._select_device(self.config.kv_buffer_device) + else: + self.device = self._select_device(device) + + self.transfer_engine = MooncakeTransferEngine(self.kv_rank, + self.local_rank) + self.transport_thread: Optional[ThreadPoolExecutor] = None + self.none_tensor = torch.tensor([NONE_INT], device=self.device) + + def _select_device(self, device: str) -> torch.device: + """Select available device (CUDA or CPU).""" + logger.info("Selecting device: %s", device) + if device == "cuda": + return torch.device(f"cuda:{self.local_rank}") + else: + return torch.device("cpu") + + def tensor_hash(self, tensor: torch.Tensor) -> int: + """Calculate the hash value of the tensor.""" + return hash(tensor.data_ptr()) + + def _send_impl(self, tensor: torch.Tensor) -> None: + """Implement the tensor sending logic using safetensors.""" + self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor})) + + def _recv_impl(self) -> torch.Tensor: + """Implement the tensor receiving logic using safetensors.""" + data = self.transfer_engine.recv_bytes() + return safetensors_load(data)["tensor"].to(self.device) + + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """Send tensor to the target process.""" + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + tensor = tensor if tensor is not None else self.none_tensor + assert (len(tensor.shape) > 0) + self.transport_thread.submit(self._send_impl, tensor) + + def recv_tensor(self) -> Optional[torch.Tensor]: + """Receive tensor from other processes.""" + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + tensor = self.transport_thread.submit(self._recv_impl).result() + if tensor.numel() == 1 and tensor.item() == NONE_INT: + return None + else: + return tensor + + def close(self) -> None: + """Cleanup logic when closing the pipe.""" + self.transfer_engine.sender_socket.close() + self.transfer_engine.receiver_socket.close() + self.transfer_engine.sender_ack.close() + self.transfer_engine.receiver_ack.close() + self.transfer_engine.context.term() # Terminate the ZMQ context + logger.info("Closed the transfer engine and cleaned up resources.") diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py new file mode 100644 index 0000000..09de0b6 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py @@ -0,0 +1,280 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" + This module implements a PyNccl pipe for sending and receiving + Optional[torch.Tensor] between distributed ranks with advanced + communication features. + + Key Features: + - Supports sending and receiving tensors with metadata + - Handles both CUDA and CPU device communications + - Implements a non-blocking tensor transfer mechanism + - Manages buffer size and provides backpressure control + - Supports distributed process groups with configurable parameters +""" + +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Callable, Optional + +import torch + +from vllm.config import KVTransferConfig +from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase +from vllm.distributed.utils import StatelessProcessGroup +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class BrokenPipeException(Exception): + + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +Metadata = dict[str, Optional[torch.Tensor]] + + +class PyNcclPipe(KVPipeBase): + + METADATA_LENGTH = 16 + MAX_TENSOR_DIMENSIONS = 14 + METADATA_DTYPE = torch.int64 + + def __init__(self, + local_rank: int, + config: KVTransferConfig, + device: Optional[str] = None, + port_offset: int = 0): + self.config = config + self.local_rank = local_rank + self.kv_rank = self.config.kv_rank + self.kv_parallel_size = self.config.kv_parallel_size + if device is None: + self.device = self._select_device(self.config.kv_buffer_device) + else: + self.device = self._select_device(device) + + # build distributed connection and send/recv implementation + store_timeout = self.config.get_from_extra_config("store_timeout", 300) + self.group = StatelessProcessGroup.create( + host=self.config.kv_ip, + port=self.config.kv_port + port_offset, + rank=self.kv_rank, + world_size=self.kv_parallel_size, + store_timeout=store_timeout, + ) + # add a barrier to make sure the connection is initiated properly + self.group.barrier() + impl = self._get_device_send_recv_impl(self.group) + self.device_send_func, self.device_recv_func = impl + # set target rank + self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size + self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size + + # transportation-related variables + self.transport_thread: Optional[ThreadPoolExecutor] = None + self.buffer_size = 0 + self.buffer_size_lock = threading.Lock() + self.buffer_size_thresh = self.config.kv_buffer_size + + def _get_device_send_recv_impl( + self, group: StatelessProcessGroup + ) -> tuple[Callable[[torch.Tensor, int], None], Callable[ + [torch.Tensor, int], None]]: + + send: Callable[[torch.Tensor, int], None] + recv: Callable[[torch.Tensor, int], None] + if self.device.type == "cuda": + # use PyNCCL for send / recv + comm = PyNcclCommunicator(group, device=self.local_rank) + comm.disabled = False + send, recv = comm.send, comm.recv # type: ignore + else: + # This send / recv implementation here is NOT intended to transfer + # KV caches (and should NOT be repurposed to transfer KV caches). + # Currently it is only used to transmit control-plane messages + # for PyNcclBuffer. + send = group.send_obj + + def my_recv(x, src): + x[...] = group.recv_obj(src) + + recv = my_recv + + return send, recv + + def _select_device(self, device: str): + logger.info("Selecting device: %s", device) + if device == "cuda": + return torch.device(f"cuda:{self.local_rank}") + else: + return torch.device("cpu") + + def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata: + """ + Create the metadata as a dictionary based on the input tensor. + + Args: + tensor: The input tensor or None if no tensor is provided. + + Returns: + metadata: A dictionary with the following keys: + - "dtype": The data type of the tensor or None. + - "shape": The shape of the tensor or None. + """ + if tensor is None: + return {"dtype": None, "shape": None} + else: + return {"dtype": tensor.dtype, "shape": tensor.shape} + + def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor: + """ + Create a buffer to receive the tensor based on the provided metadata. + + Args: + metadata: A dictionary with keys "dtype" and "shape", + describing the tensor's data type and shape. + + Returns: + buffer: A tensor of the specified type and shape, + allocated on `self.device`. + """ + return torch.empty(metadata["shape"], + dtype=metadata["dtype"], + device=self.device) + + def _send_metadata(self, metadata: Metadata): + """ + Send the metadata dictionary to the target rank. + + Args: + metadata: A dictionary with keys "dtype" and "shape". + """ + self.group.send_obj(metadata, self.target_rank_for_send) + + def _recv_metadata(self) -> Metadata: + """ + Receive the metadata dictionary from the target rank. + + Returns: + metadata: A dictionary with keys "dtype" and "shape" + describing the tensor. + """ + return self.group.recv_obj(self.target_rank_for_recv) + + def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: + """ + The actual implementation of sending the tensor and its metadata to the + target rank. + + Args: + tensor: The input tensor to be sent, or `None` if no tensor is + being sent. + """ + metadata = self._make_metadata(tensor) + self._send_metadata(metadata) + if tensor is not None: + self.device_send_func(tensor.to(self.device), + self.target_rank_for_send) + + def _recv_impl(self) -> Optional[torch.Tensor]: + """ + The actual implementation of receiving a tensor and its metadata from + the target rank. + + Returns: + buffer: The received tensor, or `None` if no tensor is received. + """ + metadata = self._recv_metadata() + if metadata["dtype"] is None: + return None + buffer = self._prepare_recv_buffer(metadata) + self.device_recv_func(buffer, self.target_rank_for_recv) + + return buffer + + def send_tensor_wrapper(self, tensor: Optional[torch.Tensor], + tensor_size: int) -> None: + """ + Wrapper for _send_impl to handle exceptions and update buffer size. + """ + try: + self._send_impl(tensor) + + with self.buffer_size_lock: + self.buffer_size -= tensor_size + except Exception as e: + logger.error("[rank%d]: Exception when trying to send %s, msg: %s", + torch.distributed.get_rank(), str(tensor), str(e)) + import traceback + traceback.print_exc() + + def block_if_full(self): + """ + Block the current thread if the buffer size is larger than the + threshold. + """ + while self.buffer_size > self.buffer_size_thresh: + logger.debug("KV cache transfer pipe is full. Waiting...") + time.sleep(0.05) + + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """ + Sends a tensor and its metadata to the destination rank in a + non-blocking way. + + Args: + tensor: The tensor to send, or `None` if no tensor is being sent. + """ + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + + if tensor is not None: + tensor_size = tensor.element_size() * tensor.numel() + else: + tensor_size = 0 + + self.block_if_full() + + with self.buffer_size_lock: + self.buffer_size += tensor_size + + self.transport_thread.submit(self.send_tensor_wrapper, tensor, + tensor_size) + + def recv_tensor(self) -> Optional[torch.Tensor]: + """ + Receives a tensor and its metadata from the source rank. Blocking call. + + Args: + tensor: The received tensor, or `None` if no tensor is received. + """ + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + + future = self.transport_thread.submit(self._recv_impl) + + try: + tensor = future.result() + except Exception as e: + logger.error("Encountering exception in KV receiving thread") + logger.error("%s", e) + logger.error("My device: %s", self.device) + import traceback + traceback.print_exc() + raise e + + return tensor + + def close(self): + """ + Close the pipe and release associated resources. + """ + if hasattr(self, + "transport_thread") and self.transport_thread is not None: + self.transport_thread.shutdown() diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py new file mode 100644 index 0000000..60f1d5d --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING, Optional + +from vllm import envs +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, + KVConnectorRole) +from vllm.distributed.parallel_state import get_world_group + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +_KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None + + +def get_kv_transfer_group() -> KVConnectorBaseType: + assert _KV_CONNECTOR_AGENT is not None, ( + "disaggregated KV cache transfer parallel group is not initialized") + return _KV_CONNECTOR_AGENT + + +def has_kv_transfer_group() -> bool: + return _KV_CONNECTOR_AGENT is not None + + +def is_v1_kv_transfer_group( + connector: Optional[KVConnectorBaseType] = None) -> bool: + """Check if the KV connector is the v1 connector. + If the argument is None, it will check the global KV connector + + Args: + connector: The KV connector to check. If None, it will check the + global KV connector. + + Note: + This function will no-longer be needed after the v1 KV connector + becomes the default. + """ + if connector is None: + connector = _KV_CONNECTOR_AGENT + + if connector is None: + return False + + return isinstance(connector, KVConnectorBase_V1) + + +def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: + """ + Initialize KV cache transfer parallel group. + """ + + global _KV_CONNECTOR_AGENT + + if vllm_config.kv_transfer_config is None: + return + + if (vllm_config.kv_transfer_config.is_kv_transfer_instance + and _KV_CONNECTOR_AGENT is None): + if envs.VLLM_USE_V1: + _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1( + config=vllm_config, role=KVConnectorRole.WORKER) + else: + _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0( + rank=get_world_group().rank, + local_rank=get_world_group().local_rank, + config=vllm_config, + ) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py new file mode 100644 index 0000000..f72f769 --- /dev/null +++ b/vllm/distributed/parallel_state.py @@ -0,0 +1,1386 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +"""vLLM distributed state. +It takes over the control of the distributed environment from PyTorch. +The typical workflow is: + +- call `init_distributed_environment` to initialize the distributed environment. +- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to + initialize the model parallel groups. + +- any code dealing with the distributed stuff + +- call `destroy_model_parallel` to destroy the model parallel groups. +- call `destroy_distributed_environment` to destroy the distributed environment. + +If you only need to use the distributed environment without model/pipeline + parallelism, you can skip the model parallel initialization and destruction + steps. +""" +import contextlib +import gc +import pickle +import weakref +from collections import namedtuple +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass +from multiprocessing import shared_memory +from typing import Any, Callable, Optional, Union +from unittest.mock import patch + +import torch +import torch.distributed +from torch.distributed import Backend, ProcessGroup + +import vllm.envs as envs +from vllm.distributed.device_communicators.base_device_communicator import ( + DeviceCommunicatorBase) +from vllm.distributed.utils import StatelessProcessGroup +from vllm.logger import init_logger +from vllm.utils import (direct_register_custom_op, get_distributed_init_method, + resolve_obj_by_qualname, supports_custom_op) + + +@dataclass +class GraphCaptureContext: + stream: torch.cuda.Stream + + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: dict[str, Union[torch.Tensor, Any]] +) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + """ + metadata_list: list[tuple[str, Any]] = [] + tensor_list: list[torch.Tensor] = [] + for key, value in tensor_dict.items(): + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. + # receiving side will set the device index. + device = value.device.type + metadata_list.append( + (key, TensorMetadata(device, value.dtype, value.size()))) + tensor_list.append(value) + else: + metadata_list.append((key, value)) + return metadata_list, tensor_list + + +_group_name_counter: dict[str, int] = {} + + +def _get_unique_name(name: str) -> str: + """Get a unique name for the group. + Example: + _get_unique_name("tp") -> "tp:0" + _get_unique_name("tp") -> "tp:1" + """ + if name not in _group_name_counter: + _group_name_counter[name] = 0 + newname = f"{name}:{_group_name_counter[name]}" + _group_name_counter[name] += 1 + return newname + + +_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {} + + +def _register_group(group: "GroupCoordinator") -> None: + _groups[group.unique_name] = weakref.ref(group) + + +def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._all_reduce_out_place(tensor) + + +def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + return torch.empty_like(tensor) + + +def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._reduce_scatter_out_place(tensor, dim) + + +def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + new_shape = list(tensor.shape) + new_shape[dim] = tensor.shape[dim] // world_size + return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device) + + +def all_gather(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._all_gather_out_place(tensor, dim) + + +def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + new_shape = list(tensor.shape) + new_shape[dim] = tensor.shape[dim] * world_size + return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device) + + +if supports_custom_op(): + from vllm.platforms import current_platform + direct_register_custom_op( + op_name="all_reduce", + op_func=all_reduce, + mutates_args=[], + fake_impl=all_reduce_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="reduce_scatter", + op_func=reduce_scatter, + mutates_args=[], + fake_impl=reduce_scatter_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="all_gather", + op_func=all_gather, + mutates_args=[], + fake_impl=all_gather_fake, + dispatch_key=current_platform.dispatch_key, + ) + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It manages both CPU and device + communication. + """ + + # available attributes: + rank: int # global rank + ranks: list[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + use_device_communicator: bool # whether to use device communicator + device_communicator: DeviceCommunicatorBase # device communicator + mq_broadcaster: Optional[Any] # shared memory broadcaster + + def __init__( + self, + group_ranks: list[list[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + use_device_communicator: bool, + use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, + ): + group_name = group_name or "anonymous" + self.unique_name = _get_unique_name(group_name) + _register_group(self) + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None + assert self.device_group is not None + + from vllm.platforms import current_platform + + if current_platform.is_cuda_alike(): + self.device = torch.device(f"cuda:{local_rank}") + elif current_platform.is_out_of_tree(): + self.device = torch.device( + f"{current_platform.device_name}:{local_rank}") + else: + self.device = torch.device("cpu") + + self.use_device_communicator = use_device_communicator + + self.device_communicator: DeviceCommunicatorBase = None # type: ignore + if use_device_communicator and self.world_size > 1: + device_comm_cls = resolve_obj_by_qualname( + current_platform.get_device_communicator_cls()) + self.device_communicator = device_comm_cls( + cpu_group=self.cpu_group, + device=self.device, + device_group=self.device_group, + unique_name=self.unique_name, + ) + + from vllm.distributed.device_communicators.shm_broadcast import ( + MessageQueue) + self.mq_broadcaster: Optional[MessageQueue] = None + if use_message_queue_broadcaster and self.world_size > 1: + self.mq_broadcaster = MessageQueue.create_from_process_group( + self.cpu_group, 1 << 22, 6) + + from vllm.platforms import current_platform + self.use_custom_op_call = (current_platform.is_cuda_alike() + or current_platform.is_tpu()) + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @contextmanager + def graph_capture( + self, graph_capture_context: Optional[GraphCaptureContext] = None): + if graph_capture_context is None: + stream = torch.cuda.Stream() + graph_capture_context = GraphCaptureContext(stream) + else: + stream = graph_capture_context.stream + + # only cuda uses this function, + # so we don't abstract it into the base class + maybe_ca_context = nullcontext() + from vllm.distributed.device_communicators.cuda_communicator import ( + CudaCommunicator) + if self.device_communicator is not None: + assert isinstance(self.device_communicator, CudaCommunicator) + ca_comm = self.device_communicator.ca_comm + if ca_comm is not None: + maybe_ca_context = ca_comm.capture() # type: ignore + + # ensure all initialization operations complete before attempting to + # capture the graph on another stream + curr_stream = torch.cuda.current_stream() + if curr_stream != stream: + stream.wait_stream(curr_stream) + + with torch.cuda.stream(stream), maybe_ca_context: + yield graph_capture_context + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + """ + User-facing all-reduce function before we actually call the + all-reduce operation. + + We need this because Dynamo does not support passing an arbitrary + object (`self` in this case) to a custom op. We need to pass the + group name as a string, and then look up the group coordinator from + the group name, dispatch the all-reduce operation to the group + coordinator. + + In addition, PyTorch custom ops do not support mutation or returning + a new tensor in the same op. So we always make the all-reduce operation + out-of-place. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + + if self.use_custom_op_call: + return torch.ops.vllm.all_reduce(input_, + group_name=self.unique_name) + else: + return self._all_reduce_out_place(input_) + + def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: + return self.device_communicator.all_reduce(input_) + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + if self.use_custom_op_call: + return torch.ops.vllm.all_gather(input_, + dim, + world_size, + group_name=self.unique_name) + else: + return self._all_gather_out_place(input_, dim) + + def _all_gather_out_place(self, input_: torch.Tensor, + dim: int) -> torch.Tensor: + return self.device_communicator.all_gather(input_, dim) + + def reduce_scatter(self, + input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + if self.use_custom_op_call: + return torch.ops.vllm.reduce_scatter(input_, + dim, + world_size, + group_name=self.unique_name) + else: + return self._reduce_scatter_out_place(input_, dim) + + def _reduce_scatter_out_place(self, input_: torch.Tensor, + dim: int) -> torch.Tensor: + return self.device_communicator.reduce_scatter(input_, dim) + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + return self.device_communicator.gather(input_, dst, dim) + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast(input_, + src=self.ranks[src], + group=self.device_group) + return input_ + + def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.mq_broadcaster is not None: + assert src == 0, "Message queue broadcaster only supports src=0" + return self.mq_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list([obj], + src=self.ranks[src], + group=self.cpu_group) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list(recv, + src=self.ranks[src], + group=self.cpu_group) + return recv[0] + + def broadcast_object_list(self, + obj_list: list[Any], + src: int = 0, + group: Optional[ProcessGroup] = None): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list(obj_list, + src=self.ranks[src], + group=self.device_group) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + assert dst != self.rank_in_group, ( + "Invalid destination rank. Destination rank is the same " + "as the current rank.") + + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor([object_tensor.numel()], + dtype=torch.long, + device="cpu") + + # Send object size + + torch.distributed.send(size_tensor, + dst=self.ranks[dst], + group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, + dst=self.ranks[dst], + group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + assert src < self.world_size, f"Invalid src rank ({src})" + + assert src != self.rank_in_group, ( + "Invalid source rank. Source rank is the same as the current rank." + ) + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv(size_tensor, + src=self.ranks[src], + group=self.cpu_group) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu") + + rank_object = torch.distributed.recv(object_tensor, + src=self.ranks[src], + group=self.cpu_group) + + assert rank_object == rank_size, ( + "Received object sender rank does not match the size sender rank.") + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None + ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if (not torch.distributed.is_initialized() or self.world_size == 1): + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + + rank_in_group = self.rank_in_group + if rank_in_group == src: + metadata_list: list[tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), (f"Expecting a dictionary, got {type(tensor_dict)}") + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: dict[str, Union[torch.Tensor, Any]], + dst: Optional[int] = None, + all_gather_group: Optional["GroupCoordinator"] = None, + ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + all_gather_size = (1 if all_gather_group is None else + all_gather_group.world_size) + all_gather_rank = (0 if all_gather_group is None else + all_gather_group.rank_in_group) + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + metadata_list: list[tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. + continue + + # send-allgather: send only a slice, then do allgather. + if (all_gather_group is not None + and tensor.numel() % all_gather_size == 0): + tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] + + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send(tensor, + dst=self.ranks[dst], + group=metadata_group) + else: + # use group for GPU tensors + torch.distributed.send(tensor, + dst=self.ranks[dst], + group=group) + return None + + def recv_tensor_dict( + self, + src: Optional[int] = None, + all_gather_group: Optional["GroupCoordinator"] = None, + ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + all_gather_size = (1 if all_gather_group is None else + all_gather_group.world_size) + all_gather_rank = (0 if all_gather_group is None else + all_gather_group.rank_in_group) + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = (self.rank_in_group - 1) % self.world_size + assert src < self.world_size, f"Invalid src rank ({src})" + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue + + # send-allgather: send only a slice, then do allgather. + use_all_gather = (all_gather_group is not None + and tensor.numel() % all_gather_size == 0) + + if use_all_gather: + orig_shape = tensor.shape + tensor = tensor.reshape(all_gather_size, + -1)[all_gather_rank] + + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv(tensor, + src=self.ranks[src], + group=metadata_group) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, + src=self.ranks[src], + group=group) + if use_all_gather: + # do the allgather + tensor = all_gather_group.all_gather( # type: ignore + tensor, dim=0) + tensor = tensor.reshape(orig_shape) + + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. + """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + self.device_communicator.send(tensor, dst) + + def recv(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None) -> torch.Tensor: + """Receives a tensor from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + return self.device_communicator.recv(size, dtype, src) + + def destroy(self): + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + if self.device_communicator is not None: + self.device_communicator.destroy() + if self.mq_broadcaster is not None: + self.mq_broadcaster = None + + def prepare_communication_buffer_for_model(self, model: torch.nn.Module): + if self.device_communicator is not None: + self.device_communicator.prepare_communication_buffer_for_model( + model) + + def dispatch( + self, hidden_states: torch.Tensor, + router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + if self.device_communicator is not None: + return self.device_communicator.dispatch(hidden_states, + router_logits) + else: + return hidden_states, router_logits + + def combine(self, hidden_states) -> torch.Tensor: + if self.device_communicator is not None: + return self.device_communicator.combine(hidden_states) + else: + return hidden_states + + +_WORLD: Optional[GroupCoordinator] = None +_NODE_COUNT: Optional[int] = None + + +def get_world_group() -> GroupCoordinator: + assert _WORLD is not None, ("world group is not initialized") + return _WORLD + + +def init_world_group(ranks: list[int], local_rank: int, + backend: str) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=[ranks], + local_rank=local_rank, + torch_distributed_backend=backend, + use_device_communicator=False, + group_name="world", + ) + + +def init_model_parallel_group( + group_ranks: list[list[int]], + local_rank: int, + backend: str, + use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, +) -> GroupCoordinator: + + return GroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + use_device_communicator=True, + use_message_queue_broadcaster=use_message_queue_broadcaster, + group_name=group_name, + ) + + +_TP: Optional[GroupCoordinator] = None + + +def get_tp_group() -> GroupCoordinator: + assert _TP is not None, ("tensor model parallel group is not initialized") + return _TP + + +# kept for backward compatibility +get_tensor_model_parallel_group = get_tp_group + +_PP: Optional[GroupCoordinator] = None + +_DP: Optional[GroupCoordinator] = None + + +def get_dp_group() -> GroupCoordinator: + assert _DP is not None, ("data parallel group is not initialized") + return _DP + + +_EP: Optional[GroupCoordinator] = None + + +def get_ep_group() -> GroupCoordinator: + assert _EP is not None, ("expert parallel group is not initialized") + return _EP + + +def get_pp_group() -> GroupCoordinator: + assert _PP is not None, ( + "pipeline model parallel group is not initialized") + return _PP + + +# kept for backward compatibility +get_pipeline_model_parallel_group = get_pp_group + + +@contextmanager +def graph_capture(device: torch.device): + """ + `graph_capture` is a context manager which should surround the code that + is capturing the CUDA graph. Its main purpose is to ensure that the + some operations will be run after the graph is captured, before the graph + is replayed. It returns a `GraphCaptureContext` object which contains the + necessary data for the graph capture. Currently, it only contains the + stream that the graph capture is running on. This stream is set to the + current CUDA stream when the context manager is entered and reset to the + default stream when the context manager is exited. This is to ensure that + the graph capture is running on a separate stream from the default stream, + in order to explicitly distinguish the kernels to capture + from other kernels possibly launched on background in the default stream. + """ + context = GraphCaptureContext(torch.cuda.Stream(device=device)) + with get_tp_group().graph_capture(context), get_pp_group().graph_capture( + context): + yield context + + +logger = init_logger(__name__) + +_ENABLE_CUSTOM_ALL_REDUCE = True + + +def set_custom_all_reduce(enable: bool): + global _ENABLE_CUSTOM_ALL_REDUCE + _ENABLE_CUSTOM_ALL_REDUCE = enable + + +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: str = "nccl", +): + logger.debug( + "world_size=%d rank=%d local_rank=%d " + "distributed_init_method=%s backend=%s", world_size, rank, local_rank, + distributed_init_method, backend) + from vllm.config import get_current_vllm_config + config = get_current_vllm_config() + if config is not None and config.parallel_config.data_parallel_size > 1: + parallel_config = config.parallel_config + # adjust to take into account data parallelism + # offset the rank by the data parallel rank + rank = parallel_config.data_parallel_rank * world_size + rank + local_rank = rank % torch.cuda.device_count() + # adjust the world size to take into account data parallelism + world_size = parallel_config.world_size_across_dp + ip = parallel_config.data_parallel_master_ip + port = parallel_config.get_next_dp_init_port() + distributed_init_method = get_distributed_init_method(ip, port) + logger.info( + "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP", + world_size, rank, distributed_init_method) + if not torch.distributed.is_initialized(): + assert distributed_init_method is not None, ( + "distributed_init_method must be provided when initializing " + "distributed environment") + if not torch.distributed.is_backend_available(backend): + logger.warning( + "Distributed backend %s is not available; " + "falling back to gloo.", backend) + assert torch.distributed.is_gloo_available(), ( + "Fallback Gloo backend is not available.") + backend = "gloo" + # this backend is used for WORLD + torch.distributed.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank) + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = envs.LOCAL_RANK + else: + local_rank = rank + global _WORLD, _NODE_COUNT + if _WORLD is None: + ranks = list(range(torch.distributed.get_world_size())) + _WORLD = init_world_group(ranks, local_rank, backend) + _NODE_COUNT = _node_count(_WORLD.cpu_group) + logger.debug("Detected %d nodes in the distributed environment", + _NODE_COUNT) + else: + assert _WORLD.world_size == torch.distributed.get_world_size(), ( + "world group already initialized with a different world size") + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """ + Initialize model parallel groups. + + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model + parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model + parallelism. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: + 4 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 pipeline model-parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + backend = backend or torch.distributed.get_backend( + get_world_group().device_group) + + data_parallel_size = 1 + from vllm.config import get_current_vllm_config + config = get_current_vllm_config() + if config is not None: + data_parallel_size = config.parallel_config.data_parallel_size + + # the layout order is: ExternalDP x DP x PP x TP + # ExternalDP is the data parallel group that is not part of the model, + # every dp rank can generate independently (in verl integration). + # DP is the data parallel group that is part of the model, + # all the ranks in the same DP group should generate simultaneously, + # i.e. the `generate` call in the same DP group should be called together, + # otherwise it will cause deadlock. + # to get group_ranks for each dimension, transpose that dimension to the + # last dimension, then reshape to 2D, then unbind the last dimension + all_ranks = torch.arange(world_size).reshape( + -1, data_parallel_size, pipeline_model_parallel_size, + tensor_model_parallel_size) # noqa + + # Build the tensor model-parallel groups. + global _TP + assert _TP is None, ("tensor model parallel group is already initialized") + group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + + # message queue broadcaster is only used in tensor model parallel group + _TP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=True, + group_name="tp") + + # Build the pipeline model-parallel groups. + global _PP + assert _PP is None, ( + "pipeline model parallel group is already initialized") + group_ranks = all_ranks.transpose(2, 3).reshape( + -1, pipeline_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + _PP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="pp") + + global _DP + assert _DP is None, ("data parallel group is already initialized") + group_ranks = all_ranks.transpose(1, + 3).reshape(-1, + data_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + _DP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="dp") + + global _EP + assert _EP is None, ("expert parallel group is already initialized") + group_ranks = all_ranks.transpose(1, 2).reshape( + -1, data_parallel_size * tensor_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + _EP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="ep") + + logger.info( + "rank %s in world size %s is assigned as " + "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size, + _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group, + _EP.rank_in_group) + + +def ensure_model_parallel_initialized( + tensor_model_parallel_size: int, + pipeline_model_parallel_size: int, + backend: Optional[str] = None, +) -> None: + """Helper to initialize model parallel groups if they are not initialized, + or ensure tensor-parallel and pipeline-parallel sizes are equal to expected + values if the model parallel groups are initialized. + """ + backend = backend or torch.distributed.get_backend( + get_world_group().device_group) + if not model_parallel_is_initialized(): + initialize_model_parallel(tensor_model_parallel_size, + pipeline_model_parallel_size, backend) + return + + assert ( + get_tensor_model_parallel_world_size() == tensor_model_parallel_size + ), ("tensor parallel group already initialized, but of unexpected size: " + f"{get_tensor_model_parallel_world_size()=} vs. " + f"{tensor_model_parallel_size=}") + pp_world_size = get_pp_group().world_size + assert (pp_world_size == pipeline_model_parallel_size), ( + "pipeline parallel group already initialized, but of unexpected size: " + f"{pp_world_size=} vs. " + f"{pipeline_model_parallel_size=}") + + +def prepare_communication_buffer_for_model(model: torch.nn.Module): + """Prepare the communication buffer for the model. + Traditional communication libraries like NCCL are almost + model agnostic. However, emerging new communication libraries like + MoE all2all (DeepEP) usually allocate the communication buffer + based on the model shape for optimal performance. + """ + if _TP is not None: + _TP.prepare_communication_buffer_for_model(model) + if _PP is not None: + _PP.prepare_communication_buffer_for_model(model) + if _DP is not None: + _DP.prepare_communication_buffer_for_model(model) + if _EP is not None: + _EP.prepare_communication_buffer_for_model(model) + + +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return (_TP is not None and _PP is not None) + + +_TP_STATE_PATCHED = False + + +@contextmanager +def patch_tensor_parallel_group(tp_group: GroupCoordinator): + """Patch the tp group temporarily until this function ends. + + This method is for draft workers of speculative decoding to run draft model + with different tp degree from that of target model workers. + + Args: + tp_group (GroupCoordinator): the tp group coordinator + """ + global _TP_STATE_PATCHED + assert not _TP_STATE_PATCHED, "Should not call when it's already patched" + + _TP_STATE_PATCHED = True + old_tp_group = get_tp_group() + global _TP + _TP = tp_group + try: + yield + finally: + # restore the original state + _TP_STATE_PATCHED = False + _TP = old_tp_group + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return get_tp_group().world_size + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return get_tp_group().rank_in_group + + +def get_node_count() -> int: + """Return the total number of nodes in the distributed environment. """ + assert _NODE_COUNT is not None, ( + "distributed environment is not initialized") + return _NODE_COUNT + + +def destroy_model_parallel(): + """Set the groups to none and destroy them.""" + global _TP + + if _TP: + _TP.destroy() + _TP = None + + global _PP + if _PP: + _PP.destroy() + _PP = None + + global _DP + if _DP: + _DP.destroy() + _DP = None + + global _EP + if _EP: + _EP.destroy() + _EP = None + + +def destroy_distributed_environment(): + global _WORLD, _NODE_COUNT + if _WORLD: + _WORLD.destroy() + _WORLD = None + _NODE_COUNT = None + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + +def cleanup_dist_env_and_memory(shutdown_ray: bool = False): + destroy_model_parallel() + destroy_distributed_environment() + with contextlib.suppress(AssertionError): + torch.distributed.destroy_process_group() + if shutdown_ray: + import ray # Lazy import Ray + ray.shutdown() + gc.collect() + from vllm.platforms import current_platform + empty_cache = current_platform.empty_cache + if empty_cache is not None: + empty_cache() + try: + if not current_platform.is_cpu(): + torch._C._host_emptyCache() + except AttributeError: + logger.warning( + "torch._C._host_emptyCache() only available in Pytorch >=2.5") + + +def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup], + source_rank: int = 0) -> list[bool]: + """ + This is a collective operation that returns if each rank is in the same node + as the source rank. It tests if processes are attached to the same + memory system (shared access to shared memory). + """ + if isinstance(pg, ProcessGroup): + assert torch.distributed.get_backend( + pg) != torch.distributed.Backend.NCCL, ( + "in_the_same_node_as should be tested with a non-NCCL group.") + # local rank inside the group + rank = torch.distributed.get_rank(group=pg) + world_size = torch.distributed.get_world_size(group=pg) + + # global ranks of the processes in the group + ranks = torch.distributed.get_process_group_ranks(pg) + else: + rank = pg.rank + world_size = pg.world_size + ranks = list(range(world_size)) + + # local tensor in each process to store the result + is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32) + + magic_message = b"magic_message" + shm = None + + try: + with contextlib.suppress(OSError): + if rank == source_rank: + # create a shared memory segment + shm = shared_memory.SharedMemory(create=True, size=128) + shm.buf[:len(magic_message)] = magic_message + if isinstance(pg, ProcessGroup): + torch.distributed.broadcast_object_list( + [shm.name], src=ranks[source_rank], group=pg) + else: + pg.broadcast_obj(shm.name, src=source_rank) + is_in_the_same_node[rank] = 1 + else: + # try to open the shared memory segment + if isinstance(pg, ProcessGroup): + recv = [None] + torch.distributed.broadcast_object_list( + recv, src=ranks[source_rank], group=pg) + name = recv[0] + else: + name = pg.broadcast_obj(None, src=source_rank) + # fix to https://stackoverflow.com/q/62748654/9191338 + # Python incorrectly tracks shared memory even if it is not + # created by the process. The following patch is a workaround. + with patch("multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None): + shm = shared_memory.SharedMemory(name=name) + if shm.buf[:len(magic_message)] == magic_message: + is_in_the_same_node[rank] = 1 + except Exception as e: + logger.error("Error ignored in is_in_the_same_node: %s", e) + finally: + if shm: + shm.close() + + if isinstance(pg, ProcessGroup): + torch.distributed.barrier(group=pg) + else: + pg.barrier() + + # clean up the shared memory segment + with contextlib.suppress(OSError): + if rank == source_rank and shm: + shm.unlink() + + if isinstance(pg, ProcessGroup): + torch.distributed.all_reduce(is_in_the_same_node, group=pg) + aggregated_data = is_in_the_same_node + else: + aggregated_data = torch.zeros_like(is_in_the_same_node) + for i in range(world_size): + rank_data = pg.broadcast_obj(is_in_the_same_node, src=i) + aggregated_data += rank_data + + return [x == 1 for x in aggregated_data.tolist()] + + +def is_global_first_rank() -> bool: + """ + Check if the current process is the first rank globally across all + parallelism strategies (PP, TP, DP, EP, etc.). + + Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0` + or `get_pp_group().is_first_rank`, this function checks the global rank + across all parallelism dimensions. + + Returns: + bool: True if this is the global first rank (rank 0), False otherwise. + Returns True if distributed is not initialized (single process). + """ + try: + # If world group is available, use it for the most accurate check + global _WORLD + if _WORLD is not None: + return _WORLD.is_first_rank + + # If torch distributed is not initialized, assume single process + if not torch.distributed.is_initialized(): + return True + + # Fallback to torch's global rank + return torch.distributed.get_rank() == 0 + + except Exception: + # If anything goes wrong, assume this is the first rank + return True + + +def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int: + """ + Returns the total number of nodes in the process group. + + Args: + pg: The process group to analyze + + Returns: + int: The total number of nodes + """ + if isinstance(pg, ProcessGroup): + world_size = torch.distributed.get_world_size(group=pg) + else: + world_size = pg.world_size + + if world_size == 1: + return 1 + + # Build node assignment map + node_assignment = [0] * world_size # rank -> node_id + next_node_id = 0 + + for current_rank in range(world_size): + if node_assignment[current_rank] != 0: + continue # Already assigned to a node + + # Assign current rank to a new node + next_node_id += 1 + node_assignment[current_rank] = next_node_id + + # Find all ranks on the same node as current_rank + same_node_flags = in_the_same_node_as(pg, current_rank) + for other_rank, is_same_node in enumerate(same_node_flags): + if is_same_node and node_assignment[other_rank] == 0: + node_assignment[other_rank] = next_node_id + + return next_node_id diff --git a/vllm/distributed/tpu_distributed_utils.py b/vllm/distributed/tpu_distributed_utils.py new file mode 100644 index 0000000..0a786b4 --- /dev/null +++ b/vllm/distributed/tpu_distributed_utils.py @@ -0,0 +1,178 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import OrderedDict +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_xla.distributed.spmd as xs +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) + +logger = init_logger(__name__) + + +class XlaQKVParallelLinear(nn.Module): + + def __init__(self, + qkv_linear: nn.Module, + mesh: Optional["xs.Mesh"] = None): + super().__init__() + assert isinstance(qkv_linear, QKVParallelLinear) + self.skip_bias_add = qkv_linear.skip_bias_add + self.return_bias = qkv_linear.return_bias + assert qkv_linear.tp_size == 1, "TP > 1 is only supported under SPMD." + + self.q_weight: Parameter + self.k_weight: Parameter + self.v_weight: Parameter + self.q_bias: Optional[Parameter] + self.k_bias: Optional[Parameter] + self.v_bias: Optional[Parameter] + self._load_weights_from_qkv_linear(qkv_linear) + if mesh is not None: + self._shard_weight(mesh) + + def _shard_weight(self, mesh: "xs.Mesh"): + self.q_weight = Parameter(self.q_weight.to('xla'), requires_grad=False) + self.k_weight = Parameter(self.k_weight.to('xla'), requires_grad=False) + self.v_weight = Parameter(self.v_weight.to('xla'), requires_grad=False) + xs.mark_sharding(self.q_weight, mesh, ('x', None)) + xs.mark_sharding(self.k_weight, mesh, ('x', None)) + xs.mark_sharding(self.v_weight, mesh, ('x', None)) + if self.q_bias is not None: + assert self.k_bias is not None and self.v_bias is not None, \ + "QKVParallelLinear should have q, k, and v biases together." + self.q_bias = Parameter(self.q_bias.to('xla'), requires_grad=False) + xs.mark_sharding(self.q_bias, mesh, ('x', )) + self.k_bias = Parameter(self.k_bias.to('xla'), requires_grad=False) + xs.mark_sharding(self.k_bias, mesh, ('x', )) + self.v_bias = Parameter(self.v_bias.to('xla'), requires_grad=False) + xs.mark_sharding(self.v_bias, mesh, ('x', )) + + def _load_weights_from_qkv_linear(self, qkv_linear: nn.Module): + q_proj_size, k_proj_size, _ = qkv_linear.output_sizes + # The weight of qkv linear is a concatenation of q, k, and v weights + # along the output dimension. + qkv_weight = qkv_linear.weight.data.cpu() + q_weight = Parameter(qkv_weight[:q_proj_size], requires_grad=False) + k_weight = Parameter(qkv_weight[q_proj_size:q_proj_size + k_proj_size], + requires_grad=False) + v_weight = Parameter(qkv_weight[q_proj_size + k_proj_size:], + requires_grad=False) + self.register_parameter("q_weight", q_weight) + self.register_parameter("k_weight", k_weight) + self.register_parameter("v_weight", v_weight) + + if qkv_linear.bias is not None: + q_bias = Parameter(qkv_linear.bias[:q_proj_size], + requires_grad=False) + k_bias = Parameter(qkv_linear.bias[q_proj_size:q_proj_size + + k_proj_size], + requires_grad=False) + v_bias = Parameter(qkv_linear.bias[q_proj_size + k_proj_size:], + requires_grad=False) + self.register_parameter("q_bias", q_bias) + self.register_parameter("k_bias", k_bias) + self.register_parameter("v_bias", v_bias) + else: + self.register_parameter("q_bias", None) + self.register_parameter("k_bias", None) + self.register_parameter("v_bias", None) + + def forward(self, input): + # Same forward functionality as QKVParallelLinear, but doing qkv porj + # separately. + q_bias = self.q_bias if not self.skip_bias_add else None + k_bias = self.k_bias if not self.skip_bias_add else None + v_bias = self.v_bias if not self.skip_bias_add else None + q_proj = F.linear(input, self.q_weight, q_bias) + k_proj = F.linear(input, self.k_weight, k_bias) + v_proj = F.linear(input, self.v_weight, v_bias) + # The q/k/v projections will be split outside of the QKVParallelLinear. + # Because we are replacing XlaQKVParallelLinear with the + # QKVParallelLinear, we need to concatenate q, k, and v projections to + # match the output shape of the QKVParallelLinear implementation even if + # it seems to be redundant. + # The concat and the following split will be noop, and should be + # optimized away by the compiler. + qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=-1) + output_bias = torch.cat([q_bias, k_bias, v_bias], dim=-1) if \ + self.skip_bias_add else None + if not self.return_bias: + return qkv_proj + return qkv_proj, output_bias + + +def partition_column_parallel_linear(layer: torch.nn.Module, + mesh: xs.Mesh) -> torch.nn.Module: + assert isinstance(layer, ColumnParallelLinear) + xs.mark_sharding(layer.weight, mesh, ('x', None)) + logger.debug("Applied column-parallel sharding to %s", layer) + return layer + + +def partition_row_parallel_linear(layer: torch.nn.Module, + mesh: xs.Mesh) -> torch.nn.Module: + assert isinstance(layer, RowParallelLinear) + xs.mark_sharding(layer.weight, mesh, (None, 'x')) + logger.debug("Applied row-parallel sharding to %s", layer) + return layer + + +def partition_qkv_parallel_linear(layer: torch.nn.Module, + mesh: xs.Mesh) -> torch.nn.Module: + assert isinstance(layer, QKVParallelLinear) + xla_layer = XlaQKVParallelLinear(layer, mesh) + logger.debug("Applied qkv parallel sharding to %s", layer) + return xla_layer + + +MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict([ + ("QKVParallelLinear", partition_qkv_parallel_linear), + ("ColumnParallelLinear", partition_column_parallel_linear), + ("RowParallelLinear", partition_row_parallel_linear), +]) + + +def get_fqn(module): + # Get the fully qualified name of the module + return module.__class__.__qualname__ + + +def shard_model(model: torch.nn.Module, mesh: "xs.Mesh") -> None: + """ + Recursively check a PyTorch model and apply appropriate sharding based on + the MODULE_TYPE_TO_WRAPPING_FUNC mapping. + + Args: + model: torch.nn.Module to process + mesh: An XLA SPMD mesh object used for sharding + """ + + def _process_module(module, name=None, parent=None): + for module_type, wrapping_func in MODULE_TYPE_TO_WRAPPING_FUNC.items(): + if get_fqn(module) == module_type: + wrapped_module = wrapping_func(module, mesh) + + assert parent is not None and name is not None, ( + "Top Level module is not expected to be wrapped.") + if wrapped_module is not module: + # Wrapped module and module are different py object. + # The original module should be replaced by the + # wrapped_module. + logger.debug("replace %s with %s", module, wrapped_module) + setattr(parent, name, wrapped_module) + + module = wrapped_module + break + + for child_name, child_module in list(module.named_children()): + _process_module(child_module, child_name, module) + + _process_module(model) diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py new file mode 100644 index 0000000..67f7164 --- /dev/null +++ b/vllm/distributed/utils.py @@ -0,0 +1,536 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +import dataclasses +import os +import pickle +import socket +import sys +import time +import uuid +from collections import deque +from collections.abc import Sequence +from datetime import timedelta +from typing import Any, Optional + +import torch +from torch.distributed import ProcessGroup, TCPStore +from torch.distributed.distributed_c10d import (Backend, PrefixStore, + _get_default_timeout, + _unregister_process_group) +from torch.distributed.rendezvous import rendezvous + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.utils import get_tcp_uri, is_torch_equal_or_newer + +logger = init_logger(__name__) + +# We prefer to use os.sched_yield as it results in tighter polling loops, +# measured to be around 3e-7 seconds. However on earlier versions of Python +# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0) +USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1)) + or (sys.version_info[:2] == (3, 10) + and sys.version_info[2] >= 8)) + + +def sched_yield(): + if USE_SCHED_YIELD: + os.sched_yield() + else: + time.sleep(0) + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, "{} is not divisible by {}".format( + numerator, denominator) + + +def divide(numerator, denominator): + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> Sequence[torch.Tensor]: + """ Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # NOTE: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +def get_pp_indices(num_hidden_layers: int, pp_rank: int, + pp_size: int) -> tuple[int, int]: + """Try to evenly distribute layers across partitions. + + If the number of layers is not divisible by the number of partitions, + the remaining layers are evenly distributed across all but the last + partition. The last partition is excluded because it often contains an + additional norm layer and we are attempting to balance compute. + + If `pp_size > 2` and the number of remaining layers is + `0 < x <= pp_size - 2` then the remaining layers are evenly distributed + across the middle partitions. The first and last partitions are excluded + because they contain the input and output embeddings respectively and we + are attempting to reduce maximum memory consumption across partitions. + """ + partition_list_str = envs.VLLM_PP_LAYER_PARTITION + if partition_list_str is not None: + try: + partitions = [ + int(layer) for layer in partition_list_str.split(",") + ] + except ValueError as err: + raise ValueError("Invalid partition string: {}".format( + partition_list_str)) from err + if len(partitions) != pp_size: + raise ValueError(f"{len(partitions)=} does not match {pp_size=}.") + if sum(partitions) != num_hidden_layers: + raise ValueError( + f"{sum(partitions)=} does not match {num_hidden_layers=}.") + else: + layers_per_partition = num_hidden_layers // pp_size + partitions = [layers_per_partition for _ in range(pp_size)] + + if remaining_layers := num_hidden_layers % pp_size: + for i in range(2, remaining_layers + 2): + partitions[-i] += 1 + logger.info( + "Hidden layers were unevenly partitioned: [%s]. " + "This can be manually overridden using the " + "VLLM_PP_LAYER_PARTITION environment variable", + ",".join(str(p) for p in partitions)) + + start_layer = sum(partitions[:pp_rank]) + end_layer = start_layer + partitions[pp_rank] + + return (start_layer, end_layer) + + +@dataclasses.dataclass +class StatelessProcessGroup: + """A dataclass to hold a metadata store, and the rank, world_size of the + group. Only use it to communicate metadata between processes. + For data-plane communication, create NCCL-related objects. + """ + rank: int + world_size: int + store: torch._C._distributed_c10d.Store + + # stores a reference to the socket so that the file descriptor stays alive + socket: Optional[socket.socket] + + data_expiration_seconds: int = 3600 # 1 hour + + # dst rank -> counter + send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict) + # src rank -> counter + recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict) + broadcast_send_counter: int = 0 + broadcast_recv_src_counter: dict[int, int] = dataclasses.field( + default_factory=dict) + + # A deque to store the data entries, with key and timestamp. + entries: deque[tuple[str, + float]] = dataclasses.field(default_factory=deque) + + def __post_init__(self): + assert self.rank < self.world_size + self.send_dst_counter = {i: 0 for i in range(self.world_size)} + self.recv_src_counter = {i: 0 for i in range(self.world_size)} + self.broadcast_recv_src_counter = { + i: 0 + for i in range(self.world_size) + } + + def send_obj(self, obj: Any, dst: int): + """Send an object to a destination rank.""" + self.expire_data() + key = f"send_to/{dst}/{self.send_dst_counter[dst]}" + self.store.set(key, pickle.dumps(obj)) + self.send_dst_counter[dst] += 1 + self.entries.append((key, time.time())) + + def expire_data(self): + """Expire data that is older than `data_expiration_seconds` seconds.""" + while self.entries: + # check the oldest entry + key, timestamp = self.entries[0] + if time.time() - timestamp > self.data_expiration_seconds: + self.store.delete_key(key) + self.entries.popleft() + else: + break + + def recv_obj(self, src: int) -> Any: + """Receive an object from a source rank.""" + obj = pickle.loads( + self.store.get( + f"send_to/{self.rank}/{self.recv_src_counter[src]}")) + self.recv_src_counter[src] += 1 + return obj + + def broadcast_obj(self, obj: Optional[Any], src: int) -> Any: + """Broadcast an object from a source rank to all other ranks. + It does not clean up after all ranks have received the object. + Use it for limited times, e.g., for initialization. + """ + if self.rank == src: + self.expire_data() + key = (f"broadcast_from/{src}/" + f"{self.broadcast_send_counter}") + self.store.set(key, pickle.dumps(obj)) + self.broadcast_send_counter += 1 + self.entries.append((key, time.time())) + return obj + else: + key = (f"broadcast_from/{src}/" + f"{self.broadcast_recv_src_counter[src]}") + recv_obj = pickle.loads(self.store.get(key)) + self.broadcast_recv_src_counter[src] += 1 + return recv_obj + + def all_gather_obj(self, obj: Any) -> list[Any]: + """All gather an object from all ranks.""" + gathered_objs = [] + for i in range(self.world_size): + if i == self.rank: + gathered_objs.append(obj) + self.broadcast_obj(obj, src=self.rank) + else: + recv_obj = self.broadcast_obj(None, src=i) + gathered_objs.append(recv_obj) + return gathered_objs + + def barrier(self, timeout: float = 30.0): + """A robust barrier to synchronize all ranks. + + + Uses a multi-phase approach to ensure all processes reach the barrier + before proceeding: + + 1. Each process signals it has reached the barrier + + 2. Each process signals that it has confirmed the arrival of all other + ranks. + + 3. Rank 0 waits for all other ranks to signal their departure to ensure + that all ranks have departed the barrier first. + + Args: + timeout: Maximum time in seconds to wait for each phase (in seconds) + + + Raises: + RuntimeError: If coordination fails or times out + """ + # Generate a barrier ID that is globally unique + try: + if self.rank == 0: + barrier_id = f"barrier_{uuid.uuid4()}" + self.broadcast_obj(barrier_id, src=0) + else: + barrier_id = self.broadcast_obj(None, src=0) + except Exception as e: + raise RuntimeError("Failed to broadcast barrier_id") from e + + # Phase 1: Signal arrival at barrier + # Wait for all processes to arrive + # We need all ranks to confirm the arrival of all other ranks. + # This is the key synchronization point. + arrival_key = f"arrival_{barrier_id}_{self.rank}" + try: + self.store.set(arrival_key, b"1") + except Exception as e: + raise RuntimeError("Failed to signal barrier arrival") from e + + start_time = time.time() + processes_arrived: set[int] = set() + + while len(processes_arrived) < self.world_size: + # Check for timeout + cur_time = time.time() + if cur_time - start_time > timeout: + raise RuntimeError("Barrier timed out after %f seconds", + timeout) + + # Check for each process + for i in range(self.world_size): + if i in processes_arrived: + continue + + key = f"arrival_{barrier_id}_{i}" + try: + # Try to get the key - if it exists, we'll get a value + # If it doesn't exist, it will throw an exception + self.store.get(key) + processes_arrived.add(i) + except KeyError: + # Key doesn't exist yet + pass + except Exception as check_e: + logger.debug("Error checking key existence: %s", check_e) + sched_yield() + + # Short sleep to avoid tight polling + if len(processes_arrived) < self.world_size: + sched_yield() + + # Phase 2: Signal departure from barrier + # We only care to block at this stage in rank 0, which runs the + # server side of the TCPStore. We want to make sure that all + # clients have departed the barrier before rank 0 in case the + # next thing after the barrier is a shutdown, including tearing + # down the TCPStore. Other ranks can exit the barrier immediately + # after signaling their departure. + departure_key = f"departure_{barrier_id}_{self.rank}" + try: + self.store.set(departure_key, b"1") + except Exception as e: + raise RuntimeError("Failed to signal barrier departure") from e + + if self.rank != 0: + return + + # Make rank 0 wait for all processes to signal departure + start_time = time.time() + processes_departed: set[int] = set() + + while len(processes_departed) < self.world_size: + # Check for timeout + if time.time() - start_time > timeout: + raise RuntimeError("Barrier departure timed out after %f s", + timeout) + + # Check for each process + for i in range(self.world_size): + if i in processes_departed: + continue + + key = f"departure_{barrier_id}_{i}" + try: + # Try to get the key - if it exists, we'll get a value + # If it doesn't exist, it will throw an exception + self.store.get(key) + processes_departed.add(i) + except KeyError: + # Key doesn't exist yet + pass + except Exception as check_e: + logger.debug("Error checking key existence: %s", check_e) + sched_yield() + + # Short sleep to avoid tight polling + if len(processes_departed) < self.world_size: + sched_yield() + + # Clean up keys to avoid leaking memory in the store + for i in range(self.world_size): + try: + self.store.delete_key(f"arrival_{barrier_id}_{i}") + except Exception: + logger.debug("Error deleting key: %s", + f'arrival_{barrier_id}_{i}') + + try: + self.store.delete_key(f"departure_{barrier_id}_{i}") + except Exception: + logger.debug("Error deleting key: %s", + f'departure_{barrier_id}_{i}') + + @staticmethod + def create( + host: str, + port: int, + rank: int, + world_size: int, + data_expiration_seconds: int = 3600, + store_timeout: int = 300, + ) -> "StatelessProcessGroup": + """A replacement for `torch.distributed.init_process_group` that does not + pollute the global state. + + If we have process A and process B called `torch.distributed.init_process_group` + to form a group, and then we want to form another group with process A, B, C, + D, it is not possible in PyTorch, because process A and process B have already + formed a group, and process C and process D cannot join that group. This + function is a workaround for this issue. + + `torch.distributed.init_process_group` is a global call, while this function + is a stateless call. It will return a `StatelessProcessGroup` object that can be + used for exchanging metadata. With this function, process A and process B + can call `StatelessProcessGroup.create` to form a group, and then process A, B, + C, and D can call `StatelessProcessGroup.create` to form another group. + """ # noqa + launch_server = rank == 0 + if launch_server: + # listen on the specified interface (instead of 0.0.0.0) + listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + listen_socket.bind((host, port)) + listen_socket.listen() + listen_fd = listen_socket.fileno() + else: + listen_socket = None + listen_fd = None + + store = TCPStore( + host_name=host, + port=port, + world_size=world_size, + is_master=launch_server, + timeout=timedelta(seconds=store_timeout), + use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215 + master_listen_fd=listen_fd, + ) + + return StatelessProcessGroup( + rank=rank, + world_size=world_size, + store=store, + socket=listen_socket, + data_expiration_seconds=data_expiration_seconds) + + +def init_gloo_process_group(backend: Backend, prefix_store: PrefixStore, + group_rank: int, group_size: int, + timeout: timedelta) -> ProcessGroup: + """ + Stateless init ProcessGroup with gloo backend compatible with + different torch versions. + """ + if is_torch_equal_or_newer("2.6"): + pg = ProcessGroup( + prefix_store, + group_rank, + group_size, + ) + else: + options = ProcessGroup.Options(backend=backend) + pg = ProcessGroup( + prefix_store, + group_rank, + group_size, + options, + ) + from torch.distributed.distributed_c10d import ProcessGroupGloo + backend_class = ProcessGroupGloo(prefix_store, + group_rank, + group_size, + timeout=timeout) + backend_type = ProcessGroup.BackendType.GLOO + device = torch.device("cpu") + if is_torch_equal_or_newer("2.6"): + # _set_default_backend is supported in torch >= 2.6 + pg._set_default_backend(backend_type) + backend_class._set_sequence_number_for_group() + + pg._register_backend(device, backend_type, backend_class) + return pg + + +def stateless_init_torch_distributed_process_group( + host: str, port: int, rank: int, world_size: int, + backend: str) -> ProcessGroup: + """ + A replacement for `torch.distributed.init_process_group` that does not + pollute the global state. The created ProcessGroup object can be used for + some operations such as `allreduce`, because it does not depend on the + global rank. However, some operations such as `broadcast` cannot be used + because it depends on the global rank. + + # TODO: ask for help from PyTorch team if we need the `broadcast` operation. + + This function is useful when we are not sure about the total number of + processes in the process group. For example, we may have process + 1, 2, ..., 8 who want to communicate, and process 9 might be the same + process as process 1, or it might be a different process; process 10 + might be the same process as process 5, or it might be a different process. + In this case, how can we reliably form a communication channel within + process 9 and 10, without affecting the communication channel within + process 1, 2, ..., 8? + + One possible solution is to figure out if process 9 and 10 are the same + as process 1 and 5 beforehand, and then form a communication channel + based on the information, adjusting the ranks and world_size etc. However, + figuring out the information is not always easy, and it will interfere + with the main communication channel. + + Our solution is to always form a communication channel with process 1, 2, + ..., 8, and then use this function to form another communication channel + with process 9 and 10. This way, regardless of whether process 9 and 10 + are the same as process 1 and 5, the main communication channel is + always formed with process 1, 2, ..., 8, and the additional communication + channel is formed with process 9 and 10. + """ + init_method = get_tcp_uri(host, port) + backend = Backend(backend) # it is basically string + timeout = _get_default_timeout(backend) + + store, rank, world_size = next( + rendezvous(init_method, rank, world_size, timeout=timeout)) + store.set_timeout(timeout) + + group_rank = rank + group_size = world_size + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. + prefix_store = PrefixStore(init_method, store) + + if backend == "gloo": + return init_gloo_process_group(backend=backend, + prefix_store=prefix_store, + group_rank=group_rank, + group_size=group_size, + timeout=timeout) + from vllm.platforms import current_platform + return current_platform.stateless_init_device_torch_dist_pg( + backend=backend, + prefix_store=prefix_store, + group_rank=group_rank, + group_size=group_size, + timeout=timeout) + + +def stateless_destroy_torch_distributed_process_group( + pg: ProcessGroup) -> None: + """ + Destroy ProcessGroup returned by + stateless_init_torch_distributed_process_group(). + """ + if is_torch_equal_or_newer("2.7"): + pg.shutdown() + else: + # Lazy import for non-CUDA backends. + from torch.distributed.distributed_c10d import _shutdown_backend + _shutdown_backend(pg) + + _unregister_process_group(pg.group_name) diff --git a/vllm/engine/__init__.py b/vllm/engine/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py new file mode 100644 index 0000000..06260f9 --- /dev/null +++ b/vllm/engine/arg_utils.py @@ -0,0 +1,1827 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# yapf: disable +import os +import argparse +import copy +import dataclasses +import functools +import json +import sys +import threading +import warnings +from dataclasses import MISSING, dataclass, fields, is_dataclass +from itertools import permutations +from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional, + Type, TypeVar, Union, cast, get_args, get_origin) + +import regex as re +import torch +from pydantic import TypeAdapter, ValidationError +from typing_extensions import TypeIs, deprecated + +import vllm.envs as envs +from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, + ConfigFormat, ConfigType, DecodingConfig, + DetailedTraceModules, Device, DeviceConfig, + DistributedExecutorBackend, GuidedDecodingBackend, + GuidedDecodingBackendV1, HfOverrides, KVEventsConfig, + KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig, + ModelConfig, ModelDType, ModelImpl, MultiModalConfig, + ObservabilityConfig, ParallelConfig, PoolerConfig, + PrefixCachingHashAlgo, PromptAdapterConfig, + SchedulerConfig, SchedulerPolicy, SpeculativeConfig, + TaskOption, TokenizerMode, TokenizerPoolConfig, + VllmConfig, get_attr_docs, get_field) +from vllm.executor.executor_base import ExecutorBase +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.plugins import load_general_plugins +from vllm.reasoning import ReasoningParserManager +from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 +from vllm.transformers_utils.utils import check_gguf_file +from vllm.usage.usage_lib import UsageContext +from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser, + GiB_bytes, get_ip, is_in_ray_actor) + +# yapf: enable + +logger = init_logger(__name__) + +# object is used to allow for special typing forms +T = TypeVar("T") +TypeHint = Union[type[Any], object] +TypeHintT = Union[type[T], object] + + +def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]: + + def _parse_type(val: str) -> T: + try: + if return_type is json.loads and not re.match("^{.*}$", val): + return cast(T, nullable_kvs(val)) + return return_type(val) + except ValueError as e: + raise argparse.ArgumentTypeError( + f"Value {val} cannot be converted to {return_type}.") from e + + return _parse_type + + +def optional_type( + return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]: + + def _optional_type(val: str) -> Optional[T]: + if val == "" or val == "None": + return None + return parse_type(return_type)(val) + + return _optional_type + + +def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]: + if not re.match("^{.*}$", val): + return str(val) + return optional_type(json.loads)(val) + + +@deprecated( + "Passing a JSON argument as a string containing comma separated key=value " + "pairs is deprecated. This will be removed in v0.10.0. Please use a JSON " + "string instead.") +def nullable_kvs(val: str) -> dict[str, int]: + """Parses a string containing comma separate key [str] to value [int] + pairs into a dictionary. + + Args: + val: String value to be parsed. + + Returns: + Dictionary with parsed values. + """ + out_dict: dict[str, int] = {} + for item in val.split(","): + kv_parts = [part.lower().strip() for part in item.split("=")] + if len(kv_parts) != 2: + raise argparse.ArgumentTypeError( + "Each item should be in the form KEY=VALUE") + key, value = kv_parts + + try: + parsed_value = int(value) + except ValueError as exc: + msg = f"Failed to parse value of item {key}={value}" + raise argparse.ArgumentTypeError(msg) from exc + + if key in out_dict and out_dict[key] != parsed_value: + raise argparse.ArgumentTypeError( + f"Conflicting values specified for key: {key}") + out_dict[key] = parsed_value + + return out_dict + + +def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]: + """Check if the type hint is a specific type.""" + return type_hint is type or get_origin(type_hint) is type + + +def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool: + """Check if the type hints contain a specific type.""" + return any(is_type(type_hint, type) for type_hint in type_hints) + + +def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT: + """Get the specific type from the type hints.""" + return next((th for th in type_hints if is_type(th, type)), None) + + +def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]: + """Convert Literal type hints to argparse kwargs.""" + type_hint = get_type(type_hints, Literal) + choices = get_args(type_hint) + choice_type = type(choices[0]) + if not all(isinstance(choice, choice_type) for choice in choices): + raise ValueError( + "All choices must be of the same type. " + f"Got {choices} with types {[type(c) for c in choices]}") + return {"type": choice_type, "choices": sorted(choices)} + + +def is_not_builtin(type_hint: TypeHint) -> bool: + """Check if the class is not a built-in type.""" + return type_hint.__module__ != "builtins" + + +def get_type_hints(type_hint: TypeHint) -> set[TypeHint]: + """Extract type hints from Annotated or Union type hints.""" + type_hints: set[TypeHint] = set() + origin = get_origin(type_hint) + args = get_args(type_hint) + + if origin is Annotated: + type_hints.update(get_type_hints(args[0])) + elif origin is Union: + for arg in args: + type_hints.update(get_type_hints(arg)) + else: + type_hints.add(type_hint) + + return type_hints + + +@functools.lru_cache(maxsize=30) +def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: + cls_docs = get_attr_docs(cls) + kwargs = {} + for field in fields(cls): + # Get the set of possible types for the field + type_hints: set[TypeHint] = get_type_hints(field.type) + + # If the field is a dataclass, we can use the model_validate_json + generator = (th for th in type_hints if is_dataclass(th)) + dataclass_cls = next(generator, None) + + # Get the default value of the field + if field.default is not MISSING: + default = field.default + elif field.default_factory is not MISSING: + default = field.default_factory() + + # Get the help text for the field + name = field.name + help = cls_docs[name].strip() + # Escape % for argparse + help = help.replace("%", "%%") + + # Initialise the kwargs dictionary for the field + kwargs[name] = {"default": default, "help": help} + + # Set other kwargs based on the type hints + json_tip = """\n\nShould either be a valid JSON string or JSON keys + passed individually. For example, the following sets of arguments are + equivalent:\n\n + - `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n + - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n + Additionally, list elements can be passed individually using '+': + - `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n + - `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`\n\n""" + if dataclass_cls is not None: + + def parse_dataclass(val: str, cls=dataclass_cls) -> Any: + try: + if hasattr(cls, "from_cli"): + return cls.from_cli(val) + return TypeAdapter(cls).validate_json(val) + except ValidationError as e: + raise argparse.ArgumentTypeError(repr(e)) from e + + kwargs[name]["type"] = parse_dataclass + kwargs[name]["help"] += json_tip + elif contains_type(type_hints, bool): + # Creates --no- and -- flags + kwargs[name]["action"] = argparse.BooleanOptionalAction + elif contains_type(type_hints, Literal): + kwargs[name].update(literal_to_kwargs(type_hints)) + elif contains_type(type_hints, tuple): + type_hint = get_type(type_hints, tuple) + types = get_args(type_hint) + tuple_type = types[0] + assert all(t is tuple_type for t in types if t is not Ellipsis), ( + "All non-Ellipsis tuple elements must be of the same " + f"type. Got {types}.") + kwargs[name]["type"] = tuple_type + kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types) + elif contains_type(type_hints, list): + type_hint = get_type(type_hints, list) + types = get_args(type_hint) + assert len(types) == 1, ( + "List type must have exactly one type. Got " + f"{type_hint} with types {types}") + kwargs[name]["type"] = types[0] + kwargs[name]["nargs"] = "+" + elif contains_type(type_hints, int): + kwargs[name]["type"] = int + # Special case for large integers + if name in {"max_model_len", "max_num_batched_tokens"}: + kwargs[name]["type"] = human_readable_int + elif contains_type(type_hints, float): + kwargs[name]["type"] = float + elif (contains_type(type_hints, dict) + and (contains_type(type_hints, str) + or any(is_not_builtin(th) for th in type_hints))): + kwargs[name]["type"] = union_dict_and_str + elif contains_type(type_hints, dict): + kwargs[name]["type"] = parse_type(json.loads) + kwargs[name]["help"] += json_tip + elif (contains_type(type_hints, str) + or any(is_not_builtin(th) for th in type_hints)): + kwargs[name]["type"] = str + else: + raise ValueError( + f"Unsupported type {type_hints} for argument {name}.") + + # If the type hint was a sequence of literals, use the helper function + # to update the type and choices + if get_origin(kwargs[name].get("type")) is Literal: + kwargs[name].update(literal_to_kwargs({kwargs[name]["type"]})) + + # If None is in type_hints, make the argument optional. + # But not if it's a bool, argparse will handle this better. + if type(None) in type_hints and not contains_type(type_hints, bool): + kwargs[name]["type"] = optional_type(kwargs[name]["type"]) + if kwargs[name].get("choices"): + kwargs[name]["choices"].append("None") + return kwargs + + +def get_kwargs(cls: ConfigType) -> dict[str, Any]: + """Return argparse kwargs for the given Config dataclass. + + The heavy computation is cached via functools.lru_cache, and a deep copy + is returned so callers can mutate the dictionary without affecting the + cached version. + """ + return copy.deepcopy(_compute_kwargs(cls)) + + +@dataclass +class EngineArgs: + """Arguments for vLLM engine.""" + model: str = ModelConfig.model + served_model_name: Optional[Union[ + str, List[str]]] = ModelConfig.served_model_name + tokenizer: Optional[str] = ModelConfig.tokenizer + hf_config_path: Optional[str] = ModelConfig.hf_config_path + task: TaskOption = ModelConfig.task + skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init + enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds + tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode + trust_remote_code: bool = ModelConfig.trust_remote_code + allowed_local_media_path: str = ModelConfig.allowed_local_media_path + download_dir: Optional[str] = LoadConfig.download_dir + load_format: str = LoadConfig.load_format + config_format: str = ModelConfig.config_format + dtype: ModelDType = ModelConfig.dtype + kv_cache_dtype: CacheDType = CacheConfig.cache_dtype + seed: Optional[int] = ModelConfig.seed + max_model_len: Optional[int] = ModelConfig.max_model_len + cuda_graph_sizes: list[int] = get_field(SchedulerConfig, + "cuda_graph_sizes") + # Note: Specifying a custom executor backend by passing a class + # is intended for expert use only. The API may change without + # notice. + distributed_executor_backend: Optional[Union[ + DistributedExecutorBackend, + Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend + # number of P/D disaggregation (or other disaggregation) workers + pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size + tensor_parallel_size: int = ParallelConfig.tensor_parallel_size + data_parallel_size: int = ParallelConfig.data_parallel_size + data_parallel_rank: Optional[int] = None + data_parallel_size_local: Optional[int] = None + data_parallel_address: Optional[str] = None + data_parallel_rpc_port: Optional[int] = None + data_parallel_backend: str = ParallelConfig.data_parallel_backend + enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel + enable_eplb: bool = ParallelConfig.enable_eplb + num_redundant_experts: int = ParallelConfig.num_redundant_experts + eplb_window_size: int = ParallelConfig.eplb_window_size + eplb_step_interval: int = ParallelConfig.eplb_step_interval + eplb_log_balancedness: bool = ParallelConfig.eplb_log_balancedness + max_parallel_loading_workers: Optional[ + int] = ParallelConfig.max_parallel_loading_workers + block_size: Optional[BlockSize] = CacheConfig.block_size + enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching + prefix_caching_hash_algo: PrefixCachingHashAlgo = \ + CacheConfig.prefix_caching_hash_algo + disable_sliding_window: bool = ModelConfig.disable_sliding_window + disable_cascade_attn: bool = ModelConfig.disable_cascade_attn + use_v2_block_manager: bool = True + swap_space: float = CacheConfig.swap_space + cpu_offload_gb: float = CacheConfig.cpu_offload_gb + gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization + max_num_batched_tokens: Optional[ + int] = SchedulerConfig.max_num_batched_tokens + max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills + max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills + long_prefill_token_threshold: int = \ + SchedulerConfig.long_prefill_token_threshold + max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs + max_logprobs: int = ModelConfig.max_logprobs + disable_log_stats: bool = False + revision: Optional[str] = ModelConfig.revision + code_revision: Optional[str] = ModelConfig.code_revision + rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling") + rope_theta: Optional[float] = ModelConfig.rope_theta + hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token + hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides") + tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision + quantization: Optional[QuantizationMethods] = ModelConfig.quantization + enforce_eager: bool = ModelConfig.enforce_eager + max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture + disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce + # The following three fields are deprecated and will be removed in a future + # release. Setting them will have no effect. Please remove them from your + # configurations. + tokenizer_pool_size: int = TokenizerPoolConfig.pool_size + tokenizer_pool_type: str = TokenizerPoolConfig.pool_type + tokenizer_pool_extra_config: dict = \ + get_field(TokenizerPoolConfig, "extra_config") + limit_mm_per_prompt: dict[str, int] = \ + get_field(MultiModalConfig, "limit_per_prompt") + media_io_kwargs: dict[str, dict[str, + Any]] = get_field(MultiModalConfig, + "media_io_kwargs") + mm_processor_kwargs: Optional[Dict[str, Any]] = \ + MultiModalConfig.mm_processor_kwargs + disable_mm_preprocessor_cache: bool = \ + MultiModalConfig.disable_mm_preprocessor_cache + # LoRA fields + enable_lora: bool = False + enable_lora_bias: bool = LoRAConfig.bias_enabled + max_loras: int = LoRAConfig.max_loras + max_lora_rank: int = LoRAConfig.max_lora_rank + fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras + max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras + lora_target_modules: Optional[List[str]] = LoRAConfig.lora_target_modules + lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype + lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size + long_lora_scaling_factors: Optional[tuple[float, ...]] = \ + LoRAConfig.long_lora_scaling_factors + # PromptAdapter fields + enable_prompt_adapter: bool = False + max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters + max_prompt_adapter_token: int = \ + PromptAdapterConfig.max_prompt_adapter_token + device: Device = DeviceConfig.device + num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps + multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs + ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight + num_gpu_blocks_override: Optional[ + int] = CacheConfig.num_gpu_blocks_override + num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots + model_loader_extra_config: dict = \ + get_field(LoadConfig, "model_loader_extra_config") + ignore_patterns: Optional[Union[str, + List[str]]] = LoadConfig.ignore_patterns + preemption_mode: Optional[str] = SchedulerConfig.preemption_mode + + scheduler_delay_factor: float = SchedulerConfig.delay_factor + enable_chunked_prefill: Optional[ + bool] = SchedulerConfig.enable_chunked_prefill + disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input + + disable_hybrid_kv_cache_manager: bool = ( + SchedulerConfig.disable_hybrid_kv_cache_manager) + + guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend + guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback + guided_decoding_disable_any_whitespace: bool = \ + DecodingConfig.disable_any_whitespace + guided_decoding_disable_additional_properties: bool = \ + DecodingConfig.disable_additional_properties + logits_processor_pattern: Optional[ + str] = ModelConfig.logits_processor_pattern + + speculative_config: Optional[Dict[str, Any]] = None + num_speculative_heads: Optional[int] = None + + qlora_adapter_name_or_path: Optional[str] = None + show_hidden_metrics_for_version: Optional[str] = \ + ObservabilityConfig.show_hidden_metrics_for_version + otlp_traces_endpoint: Optional[str] = \ + ObservabilityConfig.otlp_traces_endpoint + collect_detailed_traces: Optional[list[DetailedTraceModules]] = \ + ObservabilityConfig.collect_detailed_traces + disable_async_output_proc: bool = not ModelConfig.use_async_output_proc + scheduling_policy: SchedulerPolicy = SchedulerConfig.policy + scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls + + override_neuron_config: dict[str, Any] = \ + get_field(ModelConfig, "override_neuron_config") + override_pooler_config: Optional[Union[dict, PoolerConfig]] = \ + ModelConfig.override_pooler_config + compilation_config: CompilationConfig = \ + get_field(VllmConfig, "compilation_config") + worker_cls: str = ParallelConfig.worker_cls + worker_extension_cls: str = ParallelConfig.worker_extension_cls + + kv_transfer_config: Optional[KVTransferConfig] = None + kv_events_config: Optional[KVEventsConfig] = None + + generation_config: str = ModelConfig.generation_config + enable_sleep_mode: bool = ModelConfig.enable_sleep_mode + override_generation_config: dict[str, Any] = \ + get_field(ModelConfig, "override_generation_config") + model_impl: str = ModelConfig.model_impl + override_attention_dtype: str = ModelConfig.override_attention_dtype + + calculate_kv_scales: bool = CacheConfig.calculate_kv_scales + + additional_config: dict[str, Any] = \ + get_field(VllmConfig, "additional_config") + enable_reasoning: Optional[bool] = None # DEPRECATED + reasoning_parser: str = DecodingConfig.reasoning_backend + + use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load + pt_load_map_location: str = LoadConfig.pt_load_map_location + + enable_multimodal_encoder_data_parallel: bool = \ + ParallelConfig.enable_multimodal_encoder_data_parallel + + def __post_init__(self): + # support `EngineArgs(compilation_config={...})` + # without having to manually construct a + # CompilationConfig object + if isinstance(self.compilation_config, (int, dict)): + self.compilation_config = CompilationConfig.from_cli( + str(self.compilation_config)) + if self.qlora_adapter_name_or_path is not None: + warnings.warn( + "The `qlora_adapter_name_or_path` is deprecated " + "and will be removed in v0.10.0. ", + DeprecationWarning, + stacklevel=2, + ) + # Setup plugins + from vllm.plugins import load_general_plugins + load_general_plugins() + + @staticmethod + def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + """Shared CLI arguments for vLLM engine.""" + + # Model arguments + model_kwargs = get_kwargs(ModelConfig) + model_group = parser.add_argument_group( + title="ModelConfig", + description=ModelConfig.__doc__, + ) + if not ('serve' in sys.argv[1:] and '--help' in sys.argv[1:]): + model_group.add_argument("--model", **model_kwargs["model"]) + model_group.add_argument("--task", **model_kwargs["task"]) + model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"]) + model_group.add_argument("--tokenizer-mode", + **model_kwargs["tokenizer_mode"]) + model_group.add_argument("--trust-remote-code", + **model_kwargs["trust_remote_code"]) + model_group.add_argument("--dtype", **model_kwargs["dtype"]) + model_group.add_argument("--seed", **model_kwargs["seed"]) + model_group.add_argument("--hf-config-path", + **model_kwargs["hf_config_path"]) + model_group.add_argument("--allowed-local-media-path", + **model_kwargs["allowed_local_media_path"]) + model_group.add_argument("--revision", **model_kwargs["revision"]) + model_group.add_argument("--code-revision", + **model_kwargs["code_revision"]) + model_group.add_argument("--rope-scaling", + **model_kwargs["rope_scaling"]) + model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"]) + model_group.add_argument("--tokenizer-revision", + **model_kwargs["tokenizer_revision"]) + model_group.add_argument("--max-model-len", + **model_kwargs["max_model_len"]) + model_group.add_argument("--quantization", "-q", + **model_kwargs["quantization"]) + model_group.add_argument("--enforce-eager", + **model_kwargs["enforce_eager"]) + model_group.add_argument("--max-seq-len-to-capture", + **model_kwargs["max_seq_len_to_capture"]) + model_group.add_argument("--max-logprobs", + **model_kwargs["max_logprobs"]) + model_group.add_argument("--disable-sliding-window", + **model_kwargs["disable_sliding_window"]) + model_group.add_argument("--disable-cascade-attn", + **model_kwargs["disable_cascade_attn"]) + model_group.add_argument("--skip-tokenizer-init", + **model_kwargs["skip_tokenizer_init"]) + model_group.add_argument("--enable-prompt-embeds", + **model_kwargs["enable_prompt_embeds"]) + model_group.add_argument("--served-model-name", + **model_kwargs["served_model_name"]) + # This one is a special case because it is the + # opposite of ModelConfig.use_async_output_proc + model_group.add_argument( + "--disable-async-output-proc", + action="store_true", + default=EngineArgs.disable_async_output_proc, + help="Disable async output processing. This may result in " + "lower performance.") + model_group.add_argument("--config-format", + choices=[f.value for f in ConfigFormat], + **model_kwargs["config_format"]) + # This one is a special case because it can bool + # or str. TODO: Handle this in get_kwargs + model_group.add_argument("--hf-token", + type=str, + nargs="?", + const=True, + default=model_kwargs["hf_token"]["default"], + help=model_kwargs["hf_token"]["help"]) + model_group.add_argument("--hf-overrides", + **model_kwargs["hf_overrides"]) + model_group.add_argument("--override-neuron-config", + **model_kwargs["override_neuron_config"]) + model_group.add_argument("--override-pooler-config", + **model_kwargs["override_pooler_config"]) + model_group.add_argument("--logits-processor-pattern", + **model_kwargs["logits_processor_pattern"]) + model_group.add_argument("--generation-config", + **model_kwargs["generation_config"]) + model_group.add_argument("--override-generation-config", + **model_kwargs["override_generation_config"]) + model_group.add_argument("--enable-sleep-mode", + **model_kwargs["enable_sleep_mode"]) + model_group.add_argument("--model-impl", + choices=[f.value for f in ModelImpl], + **model_kwargs["model_impl"]) + model_group.add_argument("--override-attention-dtype", + **model_kwargs["override_attention_dtype"]) + + # Model loading arguments + load_kwargs = get_kwargs(LoadConfig) + load_group = parser.add_argument_group( + title="LoadConfig", + description=LoadConfig.__doc__, + ) + load_group.add_argument("--load-format", + choices=[f.value for f in LoadFormat], + **load_kwargs["load_format"]) + load_group.add_argument("--download-dir", + **load_kwargs["download_dir"]) + load_group.add_argument("--model-loader-extra-config", + **load_kwargs["model_loader_extra_config"]) + load_group.add_argument("--ignore-patterns", + **load_kwargs["ignore_patterns"]) + load_group.add_argument("--use-tqdm-on-load", + **load_kwargs["use_tqdm_on_load"]) + load_group.add_argument( + "--qlora-adapter-name-or-path", + type=str, + default=None, + help="The `--qlora-adapter-name-or-path` has no effect, do not set" + " it, and it will be removed in v0.10.0.", + deprecated=True, + ) + load_group.add_argument('--pt-load-map-location', + **load_kwargs["pt_load_map_location"]) + + # Guided decoding arguments + guided_decoding_kwargs = get_kwargs(DecodingConfig) + guided_decoding_group = parser.add_argument_group( + title="DecodingConfig", + description=DecodingConfig.__doc__, + ) + guided_decoding_group.add_argument("--guided-decoding-backend", + **guided_decoding_kwargs["backend"]) + guided_decoding_group.add_argument( + "--guided-decoding-disable-fallback", + **guided_decoding_kwargs["disable_fallback"]) + guided_decoding_group.add_argument( + "--guided-decoding-disable-any-whitespace", + **guided_decoding_kwargs["disable_any_whitespace"]) + guided_decoding_group.add_argument( + "--guided-decoding-disable-additional-properties", + **guided_decoding_kwargs["disable_additional_properties"]) + guided_decoding_group.add_argument( + "--enable-reasoning", + action=argparse.BooleanOptionalAction, + deprecated=True, + help="[DEPRECATED] The `--enable-reasoning` flag is deprecated as " + "of v0.9.0. Use `--reasoning-parser` to specify the reasoning " + "parser backend instead. This flag (`--enable-reasoning`) will be " + "removed in v0.10.0. When `--reasoning-parser` is specified, " + "reasoning mode is automatically enabled.") + guided_decoding_group.add_argument( + "--reasoning-parser", + # This choices is a special case because it's not static + choices=list(ReasoningParserManager.reasoning_parsers), + **guided_decoding_kwargs["reasoning_backend"]) + + # Parallel arguments + parallel_kwargs = get_kwargs(ParallelConfig) + parallel_group = parser.add_argument_group( + title="ParallelConfig", + description=ParallelConfig.__doc__, + ) + parallel_group.add_argument( + "--distributed-executor-backend", + **parallel_kwargs["distributed_executor_backend"]) + parallel_group.add_argument( + "--pipeline-parallel-size", "-pp", + **parallel_kwargs["pipeline_parallel_size"]) + parallel_group.add_argument("--tensor-parallel-size", "-tp", + **parallel_kwargs["tensor_parallel_size"]) + parallel_group.add_argument("--data-parallel-size", "-dp", + **parallel_kwargs["data_parallel_size"]) + parallel_group.add_argument( + '--data-parallel-rank', + '-dpn', + type=int, + help='Data parallel rank of this instance. ' + 'When set, enables external load balancer mode.') + parallel_group.add_argument('--data-parallel-size-local', + '-dpl', + type=int, + help='Number of data parallel replicas ' + 'to run on this node.') + parallel_group.add_argument('--data-parallel-address', + '-dpa', + type=str, + help='Address of data parallel cluster ' + 'head-node.') + parallel_group.add_argument('--data-parallel-rpc-port', + '-dpp', + type=int, + help='Port for data parallel RPC ' + 'communication.') + parallel_group.add_argument('--data-parallel-backend', + '-dpb', + type=str, + default='mp', + help='Backend for data parallel, either ' + '"mp" or "ray".') + parallel_group.add_argument( + "--enable-expert-parallel", + **parallel_kwargs["enable_expert_parallel"]) + parallel_group.add_argument("--enable-eplb", + **parallel_kwargs["enable_eplb"]) + parallel_group.add_argument("--num-redundant-experts", + **parallel_kwargs["num_redundant_experts"]) + parallel_group.add_argument("--eplb-window-size", + **parallel_kwargs["eplb_window_size"]) + parallel_group.add_argument("--eplb-step-interval", + **parallel_kwargs["eplb_step_interval"]) + parallel_group.add_argument("--eplb-log-balancedness", + **parallel_kwargs["eplb_log_balancedness"]) + parallel_group.add_argument( + "--max-parallel-loading-workers", + **parallel_kwargs["max_parallel_loading_workers"]) + parallel_group.add_argument( + "--ray-workers-use-nsight", + **parallel_kwargs["ray_workers_use_nsight"]) + parallel_group.add_argument( + "--disable-custom-all-reduce", + **parallel_kwargs["disable_custom_all_reduce"]) + parallel_group.add_argument("--worker-cls", + **parallel_kwargs["worker_cls"]) + parallel_group.add_argument("--worker-extension-cls", + **parallel_kwargs["worker_extension_cls"]) + parallel_group.add_argument( + "--enable-multimodal-encoder-data-parallel", + **parallel_kwargs["enable_multimodal_encoder_data_parallel"]) + + # KV cache arguments + cache_kwargs = get_kwargs(CacheConfig) + cache_group = parser.add_argument_group( + title="CacheConfig", + description=CacheConfig.__doc__, + ) + cache_group.add_argument("--block-size", **cache_kwargs["block_size"]) + cache_group.add_argument("--gpu-memory-utilization", + **cache_kwargs["gpu_memory_utilization"]) + cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"]) + cache_group.add_argument("--kv-cache-dtype", + **cache_kwargs["cache_dtype"]) + cache_group.add_argument("--num-gpu-blocks-override", + **cache_kwargs["num_gpu_blocks_override"]) + cache_group.add_argument("--enable-prefix-caching", + **cache_kwargs["enable_prefix_caching"]) + cache_group.add_argument("--prefix-caching-hash-algo", + **cache_kwargs["prefix_caching_hash_algo"]) + cache_group.add_argument("--cpu-offload-gb", + **cache_kwargs["cpu_offload_gb"]) + cache_group.add_argument("--calculate-kv-scales", + **cache_kwargs["calculate_kv_scales"]) + + # Tokenizer arguments + tokenizer_kwargs = get_kwargs(TokenizerPoolConfig) + tokenizer_group = parser.add_argument_group( + title="TokenizerPoolConfig", + description=TokenizerPoolConfig.__doc__, + ) + tokenizer_group.add_argument("--tokenizer-pool-size", + **tokenizer_kwargs["pool_size"]) + tokenizer_group.add_argument("--tokenizer-pool-type", + **tokenizer_kwargs["pool_type"]) + tokenizer_group.add_argument("--tokenizer-pool-extra-config", + **tokenizer_kwargs["extra_config"]) + + # Multimodal related configs + multimodal_kwargs = get_kwargs(MultiModalConfig) + multimodal_group = parser.add_argument_group( + title="MultiModalConfig", + description=MultiModalConfig.__doc__, + ) + multimodal_group.add_argument("--limit-mm-per-prompt", + **multimodal_kwargs["limit_per_prompt"]) + multimodal_group.add_argument("--media-io-kwargs", + **multimodal_kwargs["media_io_kwargs"]) + multimodal_group.add_argument( + "--mm-processor-kwargs", + **multimodal_kwargs["mm_processor_kwargs"]) + multimodal_group.add_argument( + "--disable-mm-preprocessor-cache", + **multimodal_kwargs["disable_mm_preprocessor_cache"]) + + # LoRA related configs + lora_kwargs = get_kwargs(LoRAConfig) + lora_group = parser.add_argument_group( + title="LoRAConfig", + description=LoRAConfig.__doc__, + ) + lora_group.add_argument( + "--enable-lora", + action=argparse.BooleanOptionalAction, + help="If True, enable handling of LoRA adapters.") + lora_group.add_argument("--enable-lora-bias", + **lora_kwargs["bias_enabled"]) + lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"]) + lora_group.add_argument("--max-lora-rank", + **lora_kwargs["max_lora_rank"]) + lora_group.add_argument('--lora-target-modules', + **lora_kwargs["lora_target_modules"]) + lora_group.add_argument("--lora-extra-vocab-size", + **lora_kwargs["lora_extra_vocab_size"]) + lora_group.add_argument( + "--lora-dtype", + **lora_kwargs["lora_dtype"], + ) + lora_group.add_argument("--long-lora-scaling-factors", + **lora_kwargs["long_lora_scaling_factors"]) + lora_group.add_argument("--max-cpu-loras", + **lora_kwargs["max_cpu_loras"]) + lora_group.add_argument("--fully-sharded-loras", + **lora_kwargs["fully_sharded_loras"]) + + # PromptAdapter related configs + prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig) + prompt_adapter_group = parser.add_argument_group( + title="PromptAdapterConfig", + description=PromptAdapterConfig.__doc__, + ) + prompt_adapter_group.add_argument( + "--enable-prompt-adapter", + action=argparse.BooleanOptionalAction, + help="If True, enable handling of PromptAdapters.") + prompt_adapter_group.add_argument( + "--max-prompt-adapters", + **prompt_adapter_kwargs["max_prompt_adapters"]) + prompt_adapter_group.add_argument( + "--max-prompt-adapter-token", + **prompt_adapter_kwargs["max_prompt_adapter_token"]) + + # Device arguments + device_kwargs = get_kwargs(DeviceConfig) + device_group = parser.add_argument_group( + title="DeviceConfig", + description=DeviceConfig.__doc__, + ) + device_group.add_argument("--device", + **device_kwargs["device"], + deprecated=True) + + # Speculative arguments + speculative_group = parser.add_argument_group( + title="SpeculativeConfig", + description=SpeculativeConfig.__doc__, + ) + speculative_group.add_argument( + "--speculative-config", + type=json.loads, + default=None, + help="The configurations for speculative decoding. Should be a " + "JSON string.") + parser.add_argument( + '--num-speculative-heads', + type=int, + default=EngineArgs.num_speculative_heads, + help='The number of speculative heads to sample from ' + 'the draft model in speculative decoding.') + + # Observability arguments + observability_kwargs = get_kwargs(ObservabilityConfig) + observability_group = parser.add_argument_group( + title="ObservabilityConfig", + description=ObservabilityConfig.__doc__, + ) + observability_group.add_argument( + "--show-hidden-metrics-for-version", + **observability_kwargs["show_hidden_metrics_for_version"]) + observability_group.add_argument( + "--otlp-traces-endpoint", + **observability_kwargs["otlp_traces_endpoint"]) + # TODO: generalise this special case + choices = observability_kwargs["collect_detailed_traces"]["choices"] + metavar = f"{{{','.join(choices)}}}" + observability_kwargs["collect_detailed_traces"]["metavar"] = metavar + observability_kwargs["collect_detailed_traces"]["choices"] += [ + ",".join(p) + for p in permutations(get_args(DetailedTraceModules), r=2) + ] + observability_group.add_argument( + "--collect-detailed-traces", + **observability_kwargs["collect_detailed_traces"]) + + # Scheduler arguments + scheduler_kwargs = get_kwargs(SchedulerConfig) + scheduler_group = parser.add_argument_group( + title="SchedulerConfig", + description=SchedulerConfig.__doc__, + ) + scheduler_group.add_argument( + "--max-num-batched-tokens", + **scheduler_kwargs["max_num_batched_tokens"]) + scheduler_group.add_argument("--max-num-seqs", + **scheduler_kwargs["max_num_seqs"]) + scheduler_group.add_argument( + "--max-num-partial-prefills", + **scheduler_kwargs["max_num_partial_prefills"]) + scheduler_group.add_argument( + "--max-long-partial-prefills", + **scheduler_kwargs["max_long_partial_prefills"]) + scheduler_group.add_argument('--cuda-graph-sizes', + **scheduler_kwargs["cuda_graph_sizes"]) + scheduler_group.add_argument( + "--long-prefill-token-threshold", + **scheduler_kwargs["long_prefill_token_threshold"]) + scheduler_group.add_argument("--num-lookahead-slots", + **scheduler_kwargs["num_lookahead_slots"]) + scheduler_group.add_argument("--scheduler-delay-factor", + **scheduler_kwargs["delay_factor"]) + scheduler_group.add_argument("--preemption-mode", + **scheduler_kwargs["preemption_mode"]) + scheduler_group.add_argument("--num-scheduler-steps", + **scheduler_kwargs["num_scheduler_steps"]) + scheduler_group.add_argument( + "--multi-step-stream-outputs", + **scheduler_kwargs["multi_step_stream_outputs"]) + scheduler_group.add_argument("--scheduling-policy", + **scheduler_kwargs["policy"]) + scheduler_group.add_argument( + "--enable-chunked-prefill", + **scheduler_kwargs["enable_chunked_prefill"]) + scheduler_group.add_argument( + "--disable-chunked-mm-input", + **scheduler_kwargs["disable_chunked_mm_input"]) + scheduler_group.add_argument("--scheduler-cls", + **scheduler_kwargs["scheduler_cls"]) + scheduler_group.add_argument( + "--disable-hybrid-kv-cache-manager", + **scheduler_kwargs["disable_hybrid_kv_cache_manager"]) + + # vLLM arguments + vllm_kwargs = get_kwargs(VllmConfig) + vllm_group = parser.add_argument_group( + title="VllmConfig", + description=VllmConfig.__doc__, + ) + vllm_group.add_argument("--kv-transfer-config", + **vllm_kwargs["kv_transfer_config"]) + vllm_group.add_argument('--kv-events-config', + **vllm_kwargs["kv_events_config"]) + vllm_group.add_argument("--compilation-config", "-O", + **vllm_kwargs["compilation_config"]) + vllm_group.add_argument("--additional-config", + **vllm_kwargs["additional_config"]) + + # Other arguments + parser.add_argument('--use-v2-block-manager', + action='store_true', + default=True, + deprecated=True, + help='[DEPRECATED] block manager v1 has been ' + 'removed and SelfAttnBlockSpaceManager (i.e. ' + 'block manager v2) is now the default. ' + 'Setting this flag to True or False' + ' has no effect on vLLM behavior.') + parser.add_argument('--disable-log-stats', + action='store_true', + help='Disable logging statistics.') + + return parser + + @classmethod + def from_cli_args(cls, args: argparse.Namespace): + # Get the list of attributes of this dataclass. + attrs = [attr.name for attr in dataclasses.fields(cls)] + # Set the attributes from the parsed arguments. + engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) + return engine_args + + def create_model_config(self) -> ModelConfig: + # gguf file needs a specific model loader and doesn't use hf_repo + if check_gguf_file(self.model): + self.quantization = self.load_format = "gguf" + + # NOTE: This is to allow model loading from S3 in CI + if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3 + and self.model in MODELS_ON_S3 + and self.load_format == LoadFormat.AUTO): # noqa: E501 + self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}" + self.load_format = LoadFormat.RUNAI_STREAMER + + return ModelConfig( + model=self.model, + hf_config_path=self.hf_config_path, + task=self.task, + tokenizer=self.tokenizer, + tokenizer_mode=self.tokenizer_mode, + trust_remote_code=self.trust_remote_code, + allowed_local_media_path=self.allowed_local_media_path, + dtype=self.dtype, + seed=self.seed, + revision=self.revision, + code_revision=self.code_revision, + rope_scaling=self.rope_scaling, + rope_theta=self.rope_theta, + hf_token=self.hf_token, + hf_overrides=self.hf_overrides, + tokenizer_revision=self.tokenizer_revision, + max_model_len=self.max_model_len, + quantization=self.quantization, + enforce_eager=self.enforce_eager, + max_seq_len_to_capture=self.max_seq_len_to_capture, + max_logprobs=self.max_logprobs, + disable_sliding_window=self.disable_sliding_window, + disable_cascade_attn=self.disable_cascade_attn, + skip_tokenizer_init=self.skip_tokenizer_init, + enable_prompt_embeds=self.enable_prompt_embeds, + served_model_name=self.served_model_name, + limit_mm_per_prompt=self.limit_mm_per_prompt, + media_io_kwargs=self.media_io_kwargs, + use_async_output_proc=not self.disable_async_output_proc, + config_format=self.config_format, + mm_processor_kwargs=self.mm_processor_kwargs, + disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache, + override_neuron_config=self.override_neuron_config, + override_pooler_config=self.override_pooler_config, + logits_processor_pattern=self.logits_processor_pattern, + generation_config=self.generation_config, + override_generation_config=self.override_generation_config, + enable_sleep_mode=self.enable_sleep_mode, + model_impl=self.model_impl, + override_attention_dtype=self.override_attention_dtype, + enable_chunked_prefill=self.enable_chunked_prefill, + ) + + def create_load_config(self) -> LoadConfig: + + if self.quantization == "bitsandbytes": + self.load_format = "bitsandbytes" + + return LoadConfig( + load_format=self.load_format, + download_dir=self.download_dir, + model_loader_extra_config=self.model_loader_extra_config, + ignore_patterns=self.ignore_patterns, + use_tqdm_on_load=self.use_tqdm_on_load, + pt_load_map_location=self.pt_load_map_location, + ) + + def create_speculative_config( + self, + target_model_config: ModelConfig, + target_parallel_config: ParallelConfig, + enable_chunked_prefill: bool, + disable_log_stats: bool, + ) -> Optional["SpeculativeConfig"]: + """Initializes and returns a SpeculativeConfig object based on + `speculative_config`. + + This function utilizes `speculative_config` to create a + SpeculativeConfig object. The `speculative_config` can either be + provided as a JSON string input via CLI arguments or directly as a + dictionary from the engine. + """ + if self.speculative_config is None: + return None + + # Note(Shangming): These parameters are not obtained from the cli arg + # '--speculative-config' and must be passed in when creating the engine + # config. + self.speculative_config.update({ + "target_model_config": target_model_config, + "target_parallel_config": target_parallel_config, + "enable_chunked_prefill": enable_chunked_prefill, + "disable_log_stats": disable_log_stats, + }) + speculative_config = SpeculativeConfig.from_dict( + self.speculative_config) + + return speculative_config + + def create_engine_config( + self, + usage_context: Optional[UsageContext] = None, + ) -> VllmConfig: + """ + Create the VllmConfig. + + NOTE: for autoselection of V0 vs V1 engine, we need to + create the ModelConfig first, since ModelConfig's attrs + (e.g. the model arch) are needed to make the decision. + + This function set VLLM_USE_V1=X if VLLM_USE_V1 is + unspecified by the user. + + If VLLM_USE_V1 is specified by the user but the VllmConfig + is incompatible, we raise an error. + """ + from vllm.platforms import current_platform + current_platform.pre_register_and_update() + + device_config = DeviceConfig( + device=cast(Device, current_platform.device_type)) + model_config = self.create_model_config() + + # * If VLLM_USE_V1 is unset, we enable V1 for "supported features" + # and fall back to V0 for experimental or unsupported features. + # * If VLLM_USE_V1=1, we enable V1 for supported + experimental + # features and raise error for unsupported features. + # * If VLLM_USE_V1=0, we disable V1. + use_v1 = False + try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1") + if try_v1 and self._is_v1_supported_oracle(model_config): + use_v1 = True + + # If user explicitly set VLLM_USE_V1, sanity check we respect it. + if envs.is_set("VLLM_USE_V1"): + assert use_v1 == envs.VLLM_USE_V1 + # Otherwise, set the VLLM_USE_V1 variable globally. + else: + envs.set_vllm_use_v1(use_v1) + + # Set default arguments for V0 or V1 Engine. + if use_v1: + self._set_default_args_v1(usage_context, model_config) + else: + self._set_default_args_v0(model_config) + + assert self.enable_chunked_prefill is not None + + if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]: + assert self.enforce_eager, ( + "Cuda graph is not supported with DualChunkFlashAttention. " + "To run the model in eager mode, set 'enforce_eager=True' " + "or use '--enforce-eager' in the CLI.") + assert current_platform.is_cuda() or current_platform.is_rocm(), ( + "DualChunkFlashAttention is supported on CUDA/ROCM platform.") + assert not use_v1, ( + "DualChunkFlashAttention is not supported on V1 engine. " + "To run the model in V0 engine, try set 'VLLM_USE_V1=0'") + + cache_config = CacheConfig( + block_size=self.block_size, + gpu_memory_utilization=self.gpu_memory_utilization, + swap_space=self.swap_space, + cache_dtype=self.kv_cache_dtype, + is_attention_free=model_config.is_attention_free, + num_gpu_blocks_override=self.num_gpu_blocks_override, + sliding_window=model_config.get_sliding_window(), + enable_prefix_caching=self.enable_prefix_caching, + prefix_caching_hash_algo=self.prefix_caching_hash_algo, + cpu_offload_gb=self.cpu_offload_gb, + calculate_kv_scales=self.calculate_kv_scales, + ) + + # Get the current placement group if Ray is initialized and + # we are in a Ray actor. If so, then the placement group will be + # passed to spawned processes. + placement_group = None + if is_in_ray_actor(): + import ray + + # This call initializes Ray automatically if it is not initialized, + # but we should not do this here. + placement_group = ray.util.get_current_placement_group() + + data_parallel_external_lb = self.data_parallel_rank is not None + if data_parallel_external_lb: + assert self.data_parallel_size_local in (1, None), ( + "data_parallel_size_local must be 1 when data_parallel_rank " + "is set") + data_parallel_size_local = 1 + elif self.data_parallel_size_local is not None: + data_parallel_size_local = self.data_parallel_size_local + else: + # Local DP size defaults to global DP size if not set. + data_parallel_size_local = self.data_parallel_size + + # DP address, used in multi-node case for torch distributed group + # and ZMQ sockets. + if self.data_parallel_address is None: + if self.data_parallel_backend == "ray": + host_ip = get_ip() + logger.info( + "Using host IP %s as ray-based data parallel address", + host_ip) + data_parallel_address = host_ip + else: + assert self.data_parallel_backend == "mp", ( + "data_parallel_backend can only be ray or mp, got %s", + self.data_parallel_backend) + data_parallel_address = ParallelConfig.data_parallel_master_ip + else: + data_parallel_address = self.data_parallel_address + + # This port is only used when there are remote data parallel engines, + # otherwise the local IPC transport is used. + data_parallel_rpc_port = self.data_parallel_rpc_port if ( + self.data_parallel_rpc_port + is not None) else ParallelConfig.data_parallel_rpc_port + + parallel_config = ParallelConfig( + pipeline_parallel_size=self.pipeline_parallel_size, + tensor_parallel_size=self.tensor_parallel_size, + data_parallel_size=self.data_parallel_size, + data_parallel_rank=self.data_parallel_rank or 0, + data_parallel_external_lb=data_parallel_external_lb, + data_parallel_size_local=data_parallel_size_local, + data_parallel_master_ip=data_parallel_address, + data_parallel_rpc_port=data_parallel_rpc_port, + data_parallel_backend=self.data_parallel_backend, + enable_expert_parallel=self.enable_expert_parallel, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.num_redundant_experts, + eplb_window_size=self.eplb_window_size, + eplb_step_interval=self.eplb_step_interval, + eplb_log_balancedness=self.eplb_log_balancedness, + max_parallel_loading_workers=self.max_parallel_loading_workers, + disable_custom_all_reduce=self.disable_custom_all_reduce, + ray_workers_use_nsight=self.ray_workers_use_nsight, + placement_group=placement_group, + distributed_executor_backend=self.distributed_executor_backend, + worker_cls=self.worker_cls, + worker_extension_cls=self.worker_extension_cls, + enable_multimodal_encoder_data_parallel=self. + enable_multimodal_encoder_data_parallel, + ) + + speculative_config = self.create_speculative_config( + target_model_config=model_config, + target_parallel_config=parallel_config, + enable_chunked_prefill=self.enable_chunked_prefill, + disable_log_stats=self.disable_log_stats, + ) + + # Reminder: Please update docs/features/compatibility_matrix.md + # If the feature combo become valid + if self.num_scheduler_steps > 1: + if speculative_config is not None: + raise ValueError("Speculative decoding is not supported with " + "multi-step (--num-scheduler-steps > 1)") + if self.enable_chunked_prefill and self.pipeline_parallel_size > 1: + raise ValueError("Multi-Step Chunked-Prefill is not supported " + "for pipeline-parallel-size > 1") + from vllm.platforms import current_platform + if current_platform.is_cpu(): + logger.warning("Multi-Step (--num-scheduler-steps > 1) is " + "currently not supported for CPUs and has been " + "disabled.") + self.num_scheduler_steps = 1 + + # make sure num_lookahead_slots is set the higher value depending on + # if we are using speculative decoding or multi-step + num_lookahead_slots = max(self.num_lookahead_slots, + self.num_scheduler_steps - 1) + num_lookahead_slots = num_lookahead_slots \ + if speculative_config is None \ + else speculative_config.num_lookahead_slots + + scheduler_config = SchedulerConfig( + runner_type=model_config.runner_type, + max_num_batched_tokens=self.max_num_batched_tokens, + max_num_seqs=self.max_num_seqs, + max_model_len=model_config.max_model_len, + cuda_graph_sizes=self.cuda_graph_sizes, + num_lookahead_slots=num_lookahead_slots, + delay_factor=self.scheduler_delay_factor, + enable_chunked_prefill=self.enable_chunked_prefill, + disable_chunked_mm_input=self.disable_chunked_mm_input, + is_multimodal_model=model_config.is_multimodal_model, + preemption_mode=self.preemption_mode, + num_scheduler_steps=self.num_scheduler_steps, + multi_step_stream_outputs=self.multi_step_stream_outputs, + send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER + and parallel_config.use_ray), + policy=self.scheduling_policy, + scheduler_cls=self.scheduler_cls, + max_num_partial_prefills=self.max_num_partial_prefills, + max_long_partial_prefills=self.max_long_partial_prefills, + long_prefill_token_threshold=self.long_prefill_token_threshold, + disable_hybrid_kv_cache_manager=self. + disable_hybrid_kv_cache_manager, + ) + + lora_config = LoRAConfig( + bias_enabled=self.enable_lora_bias, + max_lora_rank=self.max_lora_rank, + max_loras=self.max_loras, + fully_sharded_loras=self.fully_sharded_loras, + lora_extra_vocab_size=self.lora_extra_vocab_size, + long_lora_scaling_factors=self.long_lora_scaling_factors, + lora_dtype=self.lora_dtype, + max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras + and self.max_cpu_loras > 0 else None, + lora_target_modules=self.lora_target_modules) if self.enable_lora else None + + # bitsandbytes pre-quantized model need a specific model loader + if model_config.quantization == "bitsandbytes": + self.quantization = self.load_format = "bitsandbytes" + + load_config = self.create_load_config() + + prompt_adapter_config = PromptAdapterConfig( + max_prompt_adapters=self.max_prompt_adapters, + max_prompt_adapter_token=self.max_prompt_adapter_token) \ + if self.enable_prompt_adapter else None + + decoding_config = DecodingConfig( + backend=self.guided_decoding_backend, + disable_fallback=self.guided_decoding_disable_fallback, + disable_any_whitespace=self.guided_decoding_disable_any_whitespace, + disable_additional_properties=\ + self.guided_decoding_disable_additional_properties, + reasoning_backend=self.reasoning_parser + ) + + observability_config = ObservabilityConfig( + show_hidden_metrics_for_version=self. + show_hidden_metrics_for_version, + otlp_traces_endpoint=self.otlp_traces_endpoint, + collect_detailed_traces=self.collect_detailed_traces, + ) + + config = VllmConfig( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + speculative_config=speculative_config, + load_config=load_config, + decoding_config=decoding_config, + observability_config=observability_config, + prompt_adapter_config=prompt_adapter_config, + compilation_config=self.compilation_config, + kv_transfer_config=self.kv_transfer_config, + kv_events_config=self.kv_events_config, + additional_config=self.additional_config, + ) + + return config + + def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: + """Oracle for whether to use V0 or V1 Engine by default.""" + + ############################################################# + # Unsupported Feature Flags on V1. + + if self.load_format == LoadFormat.SHARDED_STATE.value: + _raise_or_fallback( + feature_name=f"--load_format {self.load_format}", + recommend_to_remove=False) + return False + + if (self.logits_processor_pattern + != EngineArgs.logits_processor_pattern): + _raise_or_fallback(feature_name="--logits-processor-pattern", + recommend_to_remove=False) + return False + + if self.preemption_mode != SchedulerConfig.preemption_mode: + _raise_or_fallback(feature_name="--preemption-mode", + recommend_to_remove=True) + return False + + if (self.disable_async_output_proc + != EngineArgs.disable_async_output_proc): + _raise_or_fallback(feature_name="--disable-async-output-proc", + recommend_to_remove=True) + return False + + if self.num_scheduler_steps != SchedulerConfig.num_scheduler_steps: + _raise_or_fallback(feature_name="--num-scheduler-steps", + recommend_to_remove=True) + return False + + if self.scheduler_delay_factor != SchedulerConfig.delay_factor: + _raise_or_fallback(feature_name="--scheduler-delay-factor", + recommend_to_remove=True) + return False + + if self.guided_decoding_backend not in get_args( + GuidedDecodingBackendV1): + _raise_or_fallback( + feature_name= + f"--guided-decoding-backend={self.guided_decoding_backend}", + recommend_to_remove=False) + return False + + # Need at least Ampere for now (FA support required). + # Skip this check if we are running on a non-GPU platform, + # or if the device capability is not available + # (e.g. in a Ray actor without GPUs). + from vllm.platforms import current_platform + if (current_platform.is_cuda() + and current_platform.get_device_capability() + and current_platform.get_device_capability().major < 8): + _raise_or_fallback(feature_name="Compute Capability < 8.0", + recommend_to_remove=False) + return False + + # No Fp8 KV cache so far. + if self.kv_cache_dtype != "auto": + fp8_attention = self.kv_cache_dtype.startswith("fp8") + will_use_fa = ( + current_platform.is_cuda() + and not envs.is_set("VLLM_ATTENTION_BACKEND") + ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" + supported = False + if current_platform.is_rocm(): + supported = True + elif fp8_attention and will_use_fa: + from vllm.attention.utils.fa_utils import ( + flash_attn_supports_fp8) + supported = flash_attn_supports_fp8() + + int8_attention = self.kv_cache_dtype.startswith("int8") + if int8_attention: + supported = True + + if not supported: + _raise_or_fallback(feature_name="--kv-cache-dtype", + recommend_to_remove=False) + return False + + # No Prompt Adapter so far. + if self.enable_prompt_adapter: + _raise_or_fallback(feature_name="--enable-prompt-adapter", + recommend_to_remove=False) + return False + + # No text embedding inputs so far. + if self.enable_prompt_embeds: + _raise_or_fallback(feature_name="--enable-prompt-embeds", + recommend_to_remove=False) + return False + + # No Mamba or Encoder-Decoder so far. + if not model_config.is_v1_compatible: + _raise_or_fallback(feature_name=model_config.architectures, + recommend_to_remove=False) + return False + + # V1 mamba models are unoptimized. + if model_config.has_inner_state and _warn_or_fallback( + feature_name="Mamba"): + return False + + # No Concurrent Partial Prefills so far. + if (self.max_num_partial_prefills + != SchedulerConfig.max_num_partial_prefills + or self.max_long_partial_prefills + != SchedulerConfig.max_long_partial_prefills): + _raise_or_fallback(feature_name="Concurrent Partial Prefill", + recommend_to_remove=False) + return False + + # No OTLP observability so far. + if (self.otlp_traces_endpoint or self.collect_detailed_traces): + _raise_or_fallback(feature_name="--otlp-traces-endpoint", + recommend_to_remove=False) + return False + + # V1 supports N-gram, Medusa, and Eagle speculative decoding. + is_ngram_enabled = False + is_eagle_enabled = False + is_medusa_enabled = False + if self.speculative_config is not None: + # This is supported but experimental (handled below). + speculative_method = self.speculative_config.get("method") + if speculative_method: + if speculative_method in ("ngram", "[ngram]"): + is_ngram_enabled = True + elif speculative_method == "medusa": + is_medusa_enabled = True + elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"): + is_eagle_enabled = True + else: + speculative_model = self.speculative_config.get("model") + if speculative_model in ("ngram", "[ngram]"): + is_ngram_enabled = True + if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled): + # Other speculative decoding methods are not supported yet. + _raise_or_fallback(feature_name="Speculative Decoding", + recommend_to_remove=False) + return False + + # No XFormers so far. + V1_BACKENDS = [ + "FLASH_ATTN_VLLM_V1", + "FLASH_ATTN", + "PALLAS", + "PALLAS_VLLM_V1", + "TRITON_ATTN_VLLM_V1", + "TRITON_MLA", + "CUTLASS_MLA_VLLM_V1", + "FLASHMLA", + "FLASHINFER", + "FLASHINFER_VLLM_V1", + "ROCM_AITER_MLA", + "TORCH_SDPA_VLLM_V1", + "FLEX_ATTENTION", + ] + if (envs.is_set("VLLM_ATTENTION_BACKEND") + and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): + name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}" + _raise_or_fallback(feature_name=name, recommend_to_remove=True) + return False + + # Platforms must decide if they can support v1 for this model + if not current_platform.supports_v1(model_config=model_config): + _raise_or_fallback( + feature_name=f"device type={current_platform.device_type}", + recommend_to_remove=False) + return False + ############################################################# + # Experimental Features - allow users to opt in. + + # Signal Handlers requires running in main thread. + if (threading.current_thread() != threading.main_thread() + and _warn_or_fallback("Engine in background thread")): + return False + + if (self.pipeline_parallel_size > 1 + and self.distributed_executor_backend + not in (ParallelConfig.distributed_executor_backend, "ray", + "mp", "external_launcher")): + name = "Pipeline Parallelism without Ray distributed executor " \ + "or multiprocessing executor or external launcher" + _raise_or_fallback(feature_name=name, recommend_to_remove=False) + return False + + # The platform may be supported on V1, but off by default for now. + if not current_platform.default_v1( # noqa: SIM103 + model_config=model_config) and _warn_or_fallback( + current_platform.device_name): + return False + + if (current_platform.is_cpu() + and model_config.get_sliding_window() is not None): + _raise_or_fallback(feature_name="sliding window (CPU backend)", + recommend_to_remove=False) + return False + + ############################################################# + + return True + + def _set_default_args_v0(self, model_config: ModelConfig) -> None: + """Set Default Arguments for V0 Engine.""" + + max_model_len = model_config.max_model_len + use_long_context = max_model_len > 32768 + if self.enable_chunked_prefill is None: + # Chunked prefill not supported for Multimodal or MLA in V0. + if model_config.is_multimodal_model or model_config.use_mla: + self.enable_chunked_prefill = False + + # Enable chunked prefill by default for long context (> 32K) + # models to avoid OOM errors in initial memory profiling phase. + elif use_long_context: + from vllm.platforms import current_platform + is_gpu = current_platform.is_cuda() + use_sliding_window = (model_config.get_sliding_window() + is not None) + use_spec_decode = self.speculative_config is not None + + if (is_gpu and not use_sliding_window and not use_spec_decode + and not self.enable_lora + and not self.enable_prompt_adapter + and model_config.runner_type != "pooling"): + self.enable_chunked_prefill = True + logger.warning( + "Chunked prefill is enabled by default for models " + "with max_model_len > 32K. Chunked prefill might " + "not work with some features or models. If you " + "encounter any issues, please disable by launching " + "with --enable-chunked-prefill=False.") + + if self.enable_chunked_prefill is None: + self.enable_chunked_prefill = False + + if not self.enable_chunked_prefill and use_long_context: + logger.warning( + "The model has a long context length (%s). This may cause" + "OOM during the initial memory profiling phase, or result " + "in low performance due to small KV cache size. Consider " + "setting --max-model-len to a smaller value.", max_model_len) + elif (self.enable_chunked_prefill + and model_config.runner_type == "pooling"): + msg = "Chunked prefill is not supported for pooling models" + raise ValueError(msg) + + # if using prefix caching, we must set a hash algo + if self.enable_prefix_caching: + # Disable prefix caching for multimodal models for VLLM_V0. + if model_config.is_multimodal_model: + logger.warning( + "--enable-prefix-caching is not supported for multimodal " + "models in V0 and has been disabled.") + self.enable_prefix_caching = False + + # VLLM_V0 only supports builtin hash algo for prefix caching. + if self.prefix_caching_hash_algo == "sha256": + raise ValueError( + "sha256 is not supported for prefix caching in V0 engine. " + "Please use 'builtin'.") + + # Set max_num_seqs to 256 for VLLM_V0. + if self.max_num_seqs is None: + self.max_num_seqs = 256 + + def _set_default_args_v1(self, usage_context: UsageContext, + model_config: ModelConfig) -> None: + """Set Default Arguments for V1 Engine.""" + + # V1 always uses chunked prefills and prefix caching + # for non-pooling tasks. + # For pooling tasks the default is False + if model_config.runner_type != "pooling": + self.enable_chunked_prefill = True + if model_config.enable_chunked_prefill is not None and \ + model_config.enable_chunked_prefill is False: + self.enable_chunked_prefill = False + if self.enable_prefix_caching is None: + self.enable_prefix_caching = True + else: + + pooling_type = model_config.pooler_config.pooling_type + + # TODO: when encoder models are supported we'll have to + # check for causal attention here. + incremental_prefill_supported = (pooling_type is not None and + pooling_type.lower() == "last") + + action = "Enabling" if \ + incremental_prefill_supported else "Disabling" + + if model_config.enable_chunked_prefill is not None and \ + model_config.enable_chunked_prefill is False: + self.enable_chunked_prefill = False + + if self.enable_chunked_prefill is None: + self.enable_chunked_prefill = incremental_prefill_supported + logger.info("(%s) chunked prefill by default", action) + if self.enable_prefix_caching is None: + self.enable_prefix_caching = incremental_prefill_supported + logger.info("(%s) prefix caching by default", action) + + if not self.enable_chunked_prefill: + self.max_num_batched_tokens = model_config.max_model_len + + # V1 should use the new scheduler by default. + # Swap it only if this arg is set to the original V0 default + if self.scheduler_cls == EngineArgs.scheduler_cls: + self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler" + + # When no user override, set the default values based on the usage + # context. + # Use different default values for different hardware. + + # Try to query the device name on the current platform. If it fails, + # it may be because the platform that imports vLLM is not the same + # as the platform that vLLM is running on (e.g. the case of scaling + # vLLM with Ray) and has no GPUs. In this case we use the default + # values for non-H100/H200 GPUs. + from vllm.platforms import current_platform + try: + device_memory = current_platform.get_device_total_memory() + device_name = current_platform.get_device_name().lower() + except Exception: + # This is only used to set default_max_num_batched_tokens + device_memory = 0 + + # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces + # throughput, see PR #17885 for more details. + # So here we do an extra device name check to prevent such regression. + if device_memory >= 70 * GiB_bytes and "a100" not in device_name: + # For GPUs like H100 and MI300x, use larger default values. + default_max_num_batched_tokens = { + UsageContext.LLM_CLASS: 16384, + UsageContext.OPENAI_API_SERVER: 8192, + } + default_max_num_seqs = { + UsageContext.LLM_CLASS: 1024, + UsageContext.OPENAI_API_SERVER: 1024, + } + else: + # TODO(woosuk): Tune the default values for other hardware. + default_max_num_batched_tokens = { + UsageContext.LLM_CLASS: 8192, + UsageContext.OPENAI_API_SERVER: 2048, + } + default_max_num_seqs = { + UsageContext.LLM_CLASS: 256, + UsageContext.OPENAI_API_SERVER: 256, + } + + # tpu specific default values. + if current_platform.is_tpu(): + default_max_num_batched_tokens_tpu = { + UsageContext.LLM_CLASS: { + 'V6E': 2048, + 'V5E': 1024, + 'V5P': 512, + }, + UsageContext.OPENAI_API_SERVER: { + 'V6E': 1024, + 'V5E': 512, + 'V5P': 256, + } + } + + # cpu specific default values. + if current_platform.is_cpu(): + default_max_num_batched_tokens = { + UsageContext.LLM_CLASS: 4096, + UsageContext.OPENAI_API_SERVER: 2048, + } + default_max_num_seqs = { + UsageContext.LLM_CLASS: 128, + UsageContext.OPENAI_API_SERVER: 32, + } + + use_context_value = usage_context.value if usage_context else None + if (self.max_num_batched_tokens is None + and usage_context in default_max_num_batched_tokens): + if current_platform.is_tpu(): + chip_name = current_platform.get_device_name() + if chip_name in default_max_num_batched_tokens_tpu[ + usage_context]: + self.max_num_batched_tokens = \ + default_max_num_batched_tokens_tpu[ + usage_context][chip_name] + else: + self.max_num_batched_tokens = \ + default_max_num_batched_tokens[usage_context] + else: + self.max_num_batched_tokens = default_max_num_batched_tokens[ + usage_context] + logger.debug( + "Setting max_num_batched_tokens to %d for %s usage context.", + self.max_num_batched_tokens, use_context_value) + + if (self.max_num_seqs is None + and usage_context in default_max_num_seqs): + self.max_num_seqs = default_max_num_seqs[usage_context] + + logger.debug("Setting max_num_seqs to %d for %s usage context.", + self.max_num_seqs, use_context_value) + + +@dataclass +class AsyncEngineArgs(EngineArgs): + """Arguments for asynchronous vLLM engine.""" + disable_log_requests: bool = False + + @staticmethod + def add_cli_args(parser: FlexibleArgumentParser, + async_args_only: bool = False) -> FlexibleArgumentParser: + # Initialize plugin to update the parser, for example, The plugin may + # adding a new kind of quantization method to --quantization argument or + # a new device to --device argument. + load_general_plugins() + if not async_args_only: + parser = EngineArgs.add_cli_args(parser) + parser.add_argument('--disable-log-requests', + action='store_true', + help='Disable logging requests.') + from vllm.platforms import current_platform + current_platform.pre_register_and_update(parser) + return parser + + +def _raise_or_fallback(feature_name: str, recommend_to_remove: bool): + # if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: + if envs.VLLM_USE_V1: + raise NotImplementedError( + f"VLLM_USE_V1=1 is not supported with {feature_name}.") + msg = f"{feature_name} is not supported by the V1 Engine. " + msg += "Falling back to V0. " + if recommend_to_remove: + msg += f"We recommend to remove {feature_name} from your config " + msg += "in favor of the V1 Engine." + logger.warning(msg) + + +def _warn_or_fallback(feature_name: str) -> bool: + # if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: + if envs.VLLM_USE_V1: + logger.warning( + "Detected VLLM_USE_V1=1 with %s. Usage should " + "be considered experimental. Please report any " + "issues on Github.", feature_name) + should_exit = False + else: + logger.info( + "%s is experimental on VLLM_USE_V1=1. " + "Falling back to V0 Engine.", feature_name) + should_exit = True + return should_exit + + +def human_readable_int(value): + """Parse human-readable integers like '1k', '2M', etc. + Including decimal values with decimal multipliers. + + Examples: + - '1k' -> 1,000 + - '1K' -> 1,024 + - '25.6k' -> 25,600 + """ + value = value.strip() + match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value) + if match: + decimal_multiplier = { + 'k': 10**3, + 'm': 10**6, + 'g': 10**9, + } + binary_multiplier = { + 'K': 2**10, + 'M': 2**20, + 'G': 2**30, + } + + number, suffix = match.groups() + if suffix in decimal_multiplier: + mult = decimal_multiplier[suffix] + return int(float(number) * mult) + elif suffix in binary_multiplier: + mult = binary_multiplier[suffix] + # Do not allow decimals with binary multipliers + try: + return int(number) * mult + except ValueError as e: + raise argparse.ArgumentTypeError("Decimals are not allowed " \ + f"with binary suffixes like {suffix}. Did you mean to use " \ + f"{number}{suffix.lower()} instead?") from e + + # Regular plain number. + return int(value) + + +# These functions are used by sphinx to build the documentation +def _engine_args_parser(): + return EngineArgs.add_cli_args(FlexibleArgumentParser()) + + +def _async_engine_args_parser(): + return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(), + async_args_only=True) \ No newline at end of file diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py new file mode 100644 index 0000000..3d7d280 --- /dev/null +++ b/vllm/engine/async_llm_engine.py @@ -0,0 +1,1200 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import copy +import time +import weakref +from functools import partial +from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, + Mapping, Optional, Set, Tuple, Type, Union) +from weakref import ReferenceType + +import vllm.envs as envs +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig, VllmConfig) +from vllm.core.scheduler import SchedulerOutputs +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_timeout import asyncio_timeout +from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState +from vllm.engine.metrics_types import StatLoggerBase +from vllm.engine.protocol import EngineClient +from vllm.executor.executor_base import ExecutorBase +from vllm.inputs import PromptType +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + get_guided_decoding_logits_processor) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.sequence import ExecuteModelRequest +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.usage.usage_lib import UsageContext +from vllm.utils import Device, weak_bind + +logger = init_logger(__name__) +ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S + + +class AsyncEngineDeadError(RuntimeError): + pass + + +def _log_task_completion(task: asyncio.Task, + error_callback: Callable[[Exception], None]) -> None: + """This function is only intended for the `engine.run_engine_loop()` task. + + In particular, that task runs a `while True` loop that can only exit if + there is an exception. + """ + + exception = None + try: + return_value = task.result() + raise AssertionError( + f"The engine background task should never finish without an " + f"exception. {return_value}") + except asyncio.exceptions.CancelledError: + # We assume that if the task is cancelled, we are gracefully shutting + # down. This should only happen on program exit. + logger.info("Engine is gracefully shutting down.") + except Exception as e: + exception = e + logger.error("Engine background task failed", exc_info=e) + error_callback(exception) + raise AsyncEngineDeadError( + "Task finished unexpectedly. This should never happen! " + "Please open an issue on GitHub. See stack trace above for the " + "actual cause.") from e + + +STOP_ITERATION = Exception() # Sentinel + + +class AsyncStream: + """A stream of RequestOutputs or PoolingRequestOutputs for a request + that can be iterated over asynchronously via an async generator.""" + + def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: + self.request_id = request_id + self._cancel = cancel + self._queue: asyncio.Queue = asyncio.Queue() + self._finished = False + + def put(self, item: Union[RequestOutput, PoolingRequestOutput, + Exception]) -> None: + if not self._finished: + self._queue.put_nowait(item) + + def finish( + self, + exception: Optional[Union[BaseException, Type[BaseException]]] = None, + ) -> None: + if not self._finished: + self._finished = True + self._queue.put_nowait( + exception if self._is_raisable(exception) else STOP_ITERATION) + + @property + def finished(self) -> bool: + return self._finished + + async def generator( + self + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: + try: + while True: + result = await self._queue.get() + if self._is_raisable(result): + if result == STOP_ITERATION: + return + raise result + yield result + except GeneratorExit: + self._cancel(self.request_id) + raise asyncio.CancelledError from None + + @staticmethod + def _is_raisable(value: Any): + return isinstance(value, BaseException) or \ + (isinstance(value, type) and \ + issubclass(value, BaseException)) + + +class RequestTracker: + """Synchronous abstraction for tracking requests.""" + + def __init__(self) -> None: + self._request_streams: Dict[str, AsyncStream] = {} + self._aborted_requests: asyncio.Queue[str] = asyncio.Queue() + self._new_requests: asyncio.Queue[Tuple[AsyncStream, + dict]] = asyncio.Queue() + self.new_requests_event = asyncio.Event() + + def __contains__(self, item): + return item in self._request_streams + + def __len__(self) -> int: + return len(self._request_streams) + + def propagate_exception(self, + exc: Exception, + request_id: Optional[str] = None) -> None: + """Propagate an exception to request streams + (all if request_id is None).""" + if request_id is not None: + self.abort_request(request_id, exception=exc) + else: + # NB: tuple() used here because self.abort_request pops the stream + # out of self._request_streams, so we can't iterate on it directly + for rid in tuple(self._request_streams.keys()): + self.abort_request(rid, exception=exc) + + def process_request_output(self, + request_output: Union[RequestOutput, + PoolingRequestOutput], + *, + verbose: bool = False) -> None: + """Process a request output from the engine.""" + request_id = request_output.request_id + finished = request_output.finished + + if finished: + stream = self._request_streams.pop(request_id, None) + else: + stream = self._request_streams.get(request_id) + # Guard against a KeyError which can occur if the request was aborted + # while the output was generated + if stream is not None: + stream.put(request_output) + if finished: + stream.finish() + + if verbose and finished: + logger.info("Finished request %s.", request_id) + + def process_exception(self, + request_id: str, + exception: BaseException, + *, + verbose: bool = False) -> None: + """Propagate an exception from the engine.""" + if verbose: + logger.info("Finished request %s.", request_id) + self.abort_request(request_id, exception=exception) + + def add_request(self, + request_id: str, + *, + verbose: bool = False, + **engine_add_request_kwargs) -> AsyncStream: + """Add a request to be sent to the engine on the next background + loop iteration.""" + if request_id in self._request_streams: + raise KeyError(f"Request {request_id} already exists.") + + abort_request = partial(self.abort_request, verbose=verbose) + stream = AsyncStream(request_id, abort_request) + self._new_requests.put_nowait((stream, { + "request_id": request_id, + **engine_add_request_kwargs + })) + + self.new_requests_event.set() + + if verbose: + logger.info("Added request %s.", request_id) + + return stream + + def abort_request(self, + request_id: str, + *, + exception: Optional[Union[BaseException, + Type[BaseException]]] = None, + verbose: bool = False) -> None: + """Abort a request during next background loop iteration.""" + if verbose: + logger.info("Aborted request %s.", request_id) + + self._aborted_requests.put_nowait(request_id) + + stream = self._request_streams.pop(request_id, None) + if stream is not None: + stream.finish(exception=exception) + + def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]: + """Get the new requests and finished requests to be + sent to the engine.""" + new_requests: List[Dict] = [] + finished_requests: Set[str] = set() + + while not self._aborted_requests.empty(): + request_id = self._aborted_requests.get_nowait() + finished_requests.add(request_id) + + while not self._new_requests.empty(): + stream, new_request = self._new_requests.get_nowait() + request_id = stream.request_id + if request_id in finished_requests: + # The request has already been aborted. + stream.finish(asyncio.CancelledError) + finished_requests.discard(request_id) + else: + self._request_streams[request_id] = stream + new_requests.append(new_request) + + return new_requests, finished_requests + + async def wait_for_new_requests(self): + if not self.has_new_requests(): + await self.new_requests_event.wait() + self.new_requests_event.clear() + + def has_new_requests(self): + return not self._new_requests.empty() + + +class _AsyncLLMEngine(LLMEngine): + """Extension of LLMEngine to add async methods.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def step_async( + self, virtual_engine: int + ) -> List[Union[RequestOutput, PoolingRequestOutput]]: + """Performs one decoding iteration and returns newly generated results. + The workers are ran asynchronously if possible. + + This function performs one decoding iteration of the engine. It first + schedules the sequences to be executed in the next iteration and the + token blocks to be swapped in/out/copy. Then, it executes the model + and updates the scheduler with the model outputs. Finally, it decodes + the sequences and returns the newly generated results. + """ + # these are cached outputs from previous iterations. None if on first + # iteration + cached_outputs = self.cached_scheduler_outputs[virtual_engine] + seq_group_metadata_list = cached_outputs.seq_group_metadata_list + scheduler_outputs = cached_outputs.scheduler_outputs + allow_async_output_proc = cached_outputs.allow_async_output_proc + + ctx = self.scheduler_contexts[virtual_engine] + + # Clear outputs for each new scheduler iteration + ctx.request_outputs.clear() + + # skip the scheduler if there are any remaining steps in the seq groups. + # This ensures that the scheduler is only called again when the current + # batch has completed. + if not self._has_remaining_steps(seq_group_metadata_list): + + # Schedule iteration + (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc + ) = self.scheduler[virtual_engine].schedule() + + ctx.seq_group_metadata_list = seq_group_metadata_list + ctx.scheduler_outputs = scheduler_outputs + + if not scheduler_outputs.is_empty(): + # this will cause mamba_cache/minimax_cache failed + # to release finished_requests_ids of the last steps + finished_requests_ids = self.scheduler[ + virtual_engine].get_and_reset_finished_requests_ids() + + # Maybe switch from async mode to sync mode + if not allow_async_output_proc and len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + + if (self.scheduler_config.is_multi_step + and scheduler_outputs.num_lookahead_slots > 0): + # cache the scheduler outputs for the next iteration if we have + # lookahead slots + self._cache_scheduler_outputs_for_multi_step( + virtual_engine, seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) + else: + finished_requests_ids = list() + + assert seq_group_metadata_list is not None + assert scheduler_outputs is not None + + if not scheduler_outputs.is_empty(): + + # Check if we have a cached last_output from the previous iteration. + # For supporting PP this is probably the best way to pass the + # sampled_token_ids, as a separate broadcast over all the PP stages + # will cause one virtual engine's microbatch to block the pipeline. + last_sampled_token_ids = \ + self._get_last_sampled_token_ids(virtual_engine) + + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + virtual_engine=virtual_engine, + num_lookahead_slots=scheduler_outputs.num_lookahead_slots, + running_queue_size=scheduler_outputs.running_queue_size, + finished_requests_ids=finished_requests_ids, + # We use ExecuteModelRequest to pass the last sampled_token_ids + # to each of the non-last PP stages for in-place prepare_input. + last_sampled_token_ids=last_sampled_token_ids) + + if allow_async_output_proc: + execute_model_req.async_callback = self.async_callbacks[ + virtual_engine] + + # Execute the model. + outputs = await self.model_executor.execute_model_async( + execute_model_req) + + # we need to do this here so that last step's sampled_token_ids can + # be passed to the next iteration for PP. + if self.scheduler_config.is_multi_step: + self._update_cached_scheduler_output(virtual_engine, outputs) + else: + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + outputs = [] + + # Finish the current step for all the sequence groups. + if self.scheduler_config.is_multi_step: + for seq_group in seq_group_metadata_list: + seq_group.finish_step() + + if not self._has_remaining_steps(seq_group_metadata_list): + # Clear the cache if we have finished all the steps + if self.scheduler_config.is_multi_step: + self.cached_scheduler_outputs[ + virtual_engine] = SchedulerOutputState() + + # is_first_step_output is True only when the num_steps of all + # the sequences are 1. When the num_steps > 1, + # multi_step_model_runner does the first-step output append. + is_first_step_output: bool = False if not seq_group_metadata_list \ + else seq_group_metadata_list[0].state.num_steps == 1 + + ctx.append_output(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=allow_async_output_proc, + is_last_step=True, + is_first_step_output=is_first_step_output) + + if outputs and allow_async_output_proc: + assert len( + outputs + ) == 1, "Async postprocessor expects only a single output set" + self._advance_to_next_step( + outputs[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) + + if not allow_async_output_proc: + self._process_model_outputs(ctx=ctx) + + # Log stats. + self.do_log_stats(scheduler_outputs, outputs) + + # Tracing + self.do_tracing(scheduler_outputs) + + else: + # Multi-step case + return ctx.request_outputs + + if not self.has_unfinished_requests(): + # Drain async postprocessor (if exists) + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + assert len(ctx.output_queue) == 0 + + return ctx.request_outputs + + async def stop_remote_worker_execution_loop_async(self) -> None: + """Stop the remote worker execution loop.""" + await self.model_executor.stop_remote_worker_execution_loop_async() + + async def get_tokenizer_async(self, + lora_request: Optional[LoRARequest] = None + ) -> AnyTokenizer: + return await ( + self.get_tokenizer_group().get_lora_tokenizer_async(lora_request)) + + async def add_request_async( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + data_parallel_rank: Optional[int] = None, + ) -> None: + """ + Async version of + [`add_request`][vllm.engine.llm_engine.LLMEngine.add_request]. + """ + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + if priority != 0 and not self.scheduler_config.policy == "priority": + raise ValueError(f"Got priority {priority} but " + "Priority scheduling is not enabled.") + if arrival_time is None: + arrival_time = time.time() + + if data_parallel_rank is not None: + raise ValueError("Targeting data_parallel_rank only supported " + "in v1 client.") + + if (isinstance(prompt, dict) + and prompt.get("prompt_embeds", None) is not None + and not prompt.get("prompt_token_ids", None)): + # We use the -2 dimension (instead of 0) in case a batched input + # of batch size 1 is passed in. + prompt["prompt_token_ids"] = [0 + ] * prompt["prompt_embeds"].shape[-2] + + processed_inputs = await self.input_preprocessor.preprocess_async( + prompt, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + if isinstance(params, SamplingParams) and \ + params.guided_decoding is not None: + # Guided decoding has an async implementation for building logits + # processors in a separate threadpool. + # We want to invoke that here instead of using the blocking + # implementation in the LLMEngine + params = await build_guided_decoding_logits_processor_async( + sampling_params=params, + tokenizer=await self.get_tokenizer_async(lora_request), + default_guided_backend=self.decoding_config.backend, + reasoning_backend=self.decoding_config.reasoning_backend, + model_config=self.model_config) + + self._add_processed_request( + request_id=request_id, + processed_inputs=processed_inputs, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers, + priority=priority, + ) + + async def check_health_async(self) -> None: + self.model_executor.check_health() + + async def collective_rpc_async(self, + method: str, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None): + raise NotImplementedError + + +async def build_guided_decoding_logits_processor_async( + sampling_params: SamplingParams, tokenizer: AnyTokenizer, + default_guided_backend: str, reasoning_backend: Optional[str], + model_config: ModelConfig) -> SamplingParams: + """Constructs logits processors based on the guided_decoding, + logits_bias, and allowed_token_ids fields in sampling_params. Deletes + those fields and adds the constructed logits processors to the + logits_processors field. Modifies sampling params in-place and returns + the modified sampling params.""" + if sampling_params.guided_decoding is None: + return sampling_params + + # Defensively copy sampling params since guided decoding logits + # processors can have different state for each request + sampling_params = copy.copy(sampling_params) + guided_decoding = sampling_params.guided_decoding + + logger.debug( + "Building guided decoding logits processor. " + "guided_decoding: %s%s", guided_decoding, + f", reasoning_backend: {reasoning_backend}" + if reasoning_backend is not None else "") + + guided_decoding.backend = guided_decoding.backend or default_guided_backend + + processor = await get_guided_decoding_logits_processor( + guided_params=guided_decoding, + tokenizer=tokenizer, + reasoning_backend=reasoning_backend, + model_config=model_config) + + if processor: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = [] + sampling_params.logits_processors.append(processor) + + # Unset guided decoding params after constructing the lp from them + sampling_params.guided_decoding = None + + return sampling_params + + +class AsyncLLMEngine(EngineClient): + """An asynchronous wrapper for [`LLMEngine`][vllm.LLMEngine]. + + This class is used to wrap the [`LLMEngine`][vllm.LLMEngine] class to + make it asynchronous. It uses asyncio to create a background loop that keeps + processing incoming requests. The [`LLMEngine`][vllm.LLMEngine] is kicked + by the generate method when there are requests in the waiting queue. The + generate method yields the outputs from the [`LLMEngine`][vllm.LLMEngine] + to the caller. + + Args: + log_requests: Whether to log the requests. + start_engine_loop: If True, the background task to run the engine + will be automatically started in the generate call. + *args: Arguments for [`LLMEngine`][vllm.LLMEngine]. + **kwargs: Arguments for [`LLMEngine`][vllm.LLMEngine]. + """ + + _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine + + def __init__(self, + *args, + log_requests: bool = True, + start_engine_loop: bool = True, + **kwargs) -> None: + if envs.VLLM_USE_V1: + raise ValueError( + "Using V0 AsyncLLMEngine, but envs.VLLM_USE_V1=True. " + "This should not happen. As a workaround, try using " + "AsyncLLMEngine.from_vllm_config(...) or explicitly set " + "VLLM_USE_V1=0 or 1 and report this issue on Github.") + + self.log_requests = log_requests + self.engine = self._engine_class(*args, **kwargs) + + # This ensures quick processing of request outputs + # so the append to asyncio queues is not delayed, + # especially for multi-step. + self.use_process_request_outputs_callback = ( + self.engine.model_config.use_async_output_proc) + + if self.use_process_request_outputs_callback: + self.engine.process_request_outputs_callback = \ + weak_bind(self.process_request_outputs) + + self.background_loop: Optional[asyncio.Future] = None + # We need to keep a reference to unshielded + # task as well to prevent it from being garbage + # collected + self._background_loop_unshielded: Optional[asyncio.Task] = None + self.start_engine_loop = start_engine_loop + self._errored_with: Optional[BaseException] = None + + # Lazy initialized fields + self._request_tracker: RequestTracker + + def __del__(self): + if rt := getattr(self, "request_tracker", None): + # Wake up engine loop so that it will exit cleanly + rt.new_requests_event.set() + + @classmethod + def _get_executor_cls(cls, + engine_config: VllmConfig) -> Type[ExecutorBase]: + return LLMEngine._get_executor_cls(engine_config) + + @classmethod + def from_vllm_config( + cls, + vllm_config: VllmConfig, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + disable_log_requests: bool = False, + disable_log_stats: bool = False, + ) -> "AsyncLLMEngine": + """Create an AsyncLLMEngine from the EngineArgs.""" + + return cls( + vllm_config=vllm_config, + executor_class=cls._get_executor_cls(vllm_config), + start_engine_loop=start_engine_loop, + log_requests=not disable_log_requests, + log_stats=not disable_log_stats, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + + @classmethod + def from_engine_args( + cls, + engine_args: AsyncEngineArgs, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + ) -> "AsyncLLMEngine": + """Creates an async LLM engine from the engine arguments.""" + + vllm_config = engine_args.create_engine_config(usage_context) + + async_engine_cls = cls + if envs.VLLM_USE_V1: + from vllm.v1.engine.async_llm import AsyncLLM as V1AsyncLLMEngine + async_engine_cls = V1AsyncLLMEngine + + return async_engine_cls.from_vllm_config( + vllm_config=vllm_config, + start_engine_loop=start_engine_loop, + usage_context=usage_context, + stat_loggers=stat_loggers, + disable_log_stats=engine_args.disable_log_stats, + disable_log_requests=engine_args.disable_log_requests, + ) + + @property + def is_running(self) -> bool: + return (self.background_loop is not None + and self._background_loop_unshielded is not None + and not self._background_loop_unshielded.done()) + + @property + def is_stopped(self) -> bool: + return self.errored or (self.background_loop is not None and + self._background_loop_unshielded is not None + and self._background_loop_unshielded.done()) + + @property + def errored(self) -> bool: + return self._errored_with is not None + + @property + def dead_error(self) -> BaseException: + return AsyncEngineDeadError( + "Background loop is not running. If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") + + def set_errored(self, exc: Exception) -> None: + self._errored_with = exc + + def _error_callback(self, exc: Exception) -> None: + self.set_errored(exc) + self._request_tracker.propagate_exception(exc) + + async def get_input_preprocessor(self) -> InputPreprocessor: + return self.engine.input_preprocessor + + async def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + return await self.engine.get_tokenizer_async(lora_request) + + def start_background_loop(self) -> None: + """Start the background loop.""" + if self.errored: + raise AsyncEngineDeadError( + "Background loop has errored already.") from self._errored_with + if self.is_running: + raise RuntimeError("Background loop is already running.") + # Initialize the RequestTracker here so it uses the right event loop. + self._request_tracker = RequestTracker() + + self._background_loop_unshielded = asyncio.get_event_loop( + ).create_task(self.run_engine_loop(weakref.ref(self))) + self._background_loop_unshielded.add_done_callback( + partial(_log_task_completion, error_callback=self._error_callback)) + self.background_loop = asyncio.shield(self._background_loop_unshielded) + + def shutdown_background_loop(self) -> None: + """ + Shut down the background loop. + + This method needs to be called during cleanup to remove + references to `self` and properly GC the resources held + by the async LLM engine (e.g., the executors as well as + their resources). + """ + if self._background_loop_unshielded is not None: + self._background_loop_unshielded.cancel() + self._background_loop_unshielded = None + self.background_loop = None + + async def engine_step(self, virtual_engine: int) -> bool: + """Kick the engine to process the waiting requests. + + Returns True if there are in-progress requests.""" + + new_requests, aborted_requests = ( + self._request_tracker.get_new_and_aborted_requests()) + + for new_request in new_requests: + # Add the request into the vLLM engine's waiting queue. + try: + await self.engine.add_request_async(**new_request) + except ValueError as e: + # TODO: use a vLLM specific error for failed validation + self._request_tracker.process_exception( + new_request["request_id"], + e, + verbose=self.log_requests, + ) + + if aborted_requests: + await self._engine_abort(aborted_requests) + + request_outputs = await self.engine.step_async(virtual_engine) + + # Put the outputs into the corresponding streams. + # If used as a callback, then already invoked inside + # LLMEngine's _process_model_outputs + if not self.use_process_request_outputs_callback: + all_finished = self.process_request_outputs(request_outputs) + else: + # For callback case, we only need to detect when all + # requests are finished + all_finished = all(request_output.finished + for request_output in request_outputs) + + return not all_finished + + def process_request_outputs(self, request_outputs) -> bool: + # Put the outputs into the corresponding streams. + all_finished = True + for request_output in request_outputs: + self._request_tracker.process_request_output( + request_output, verbose=self.log_requests) + all_finished = all_finished and request_output.finished + + return all_finished + + async def _engine_abort(self, request_ids: Iterable[str]): + self.engine.abort_request(request_ids) + + @staticmethod + async def run_engine_loop(engine_ref: ReferenceType): + """We use a weakref to the engine so that the running loop + doesn't prevent the engine being garbage collected.""" + engine: Optional[AsyncLLMEngine] = engine_ref() + if not engine: + return + + pipeline_parallel_size = \ + engine.engine.parallel_config.pipeline_parallel_size + has_requests_in_progress = [False] * pipeline_parallel_size + while True: + if not any(has_requests_in_progress): + logger.debug("Waiting for new requests...") + # Stop the execute model loop in parallel workers until there + # are more requests to process. This avoids waiting + # indefinitely in torch.distributed ops which may otherwise + # timeout, and unblocks the RPC thread in the workers so that + # they can process any other queued control plane messages, + # such as add/remove lora adapters. + await engine.engine.stop_remote_worker_execution_loop_async() + request_tracker = engine._request_tracker + # Allow engine to be garbage collected while + # waiting for new requests + del engine + await asyncio.sleep(0) + if engine_ref() is None: + return + await request_tracker.wait_for_new_requests() + engine = engine_ref() + if not engine: + return + logger.debug("Got new requests!") + requests_in_progress = [ + asyncio.create_task(engine.engine_step(ve)) + for ve in range(pipeline_parallel_size) + ] + has_requests_in_progress = [True] * pipeline_parallel_size + + # Abort if iteration takes too long due to unrecoverable errors + # (eg. NCCL timeouts). + try: + async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S): + done, _ = await asyncio.wait( + requests_in_progress, + return_when=asyncio.FIRST_COMPLETED) + for _ in range(pipeline_parallel_size): + await asyncio.sleep(0) + for task in done: + result = task.result() + virtual_engine = requests_in_progress.index(task) + has_unfinished_requests = ( + engine.engine. + has_unfinished_requests_for_virtual_engine( + virtual_engine)) + if result or has_unfinished_requests: + requests_in_progress[virtual_engine] = ( + asyncio.create_task( + engine.engine_step(virtual_engine))) + has_requests_in_progress[virtual_engine] = True + else: + has_requests_in_progress[virtual_engine] = False + except asyncio.TimeoutError as exc: + logger.error( + "Engine iteration timed out. This should never happen!") + engine.set_errored(exc) + raise + await asyncio.sleep(0) + + async def add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + data_parallel_rank: Optional[int] = None, + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: + if not self.is_running: + if self.start_engine_loop: + self.start_background_loop() + else: + raise AsyncEngineDeadError( + "Background loop is not running. If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") + + if (priority != 0 + and not self.engine.scheduler_config.policy == "priority"): + raise ValueError(f"Got priority {priority} but " + "Priority scheduling is not enabled.") + + stream = self._request_tracker.add_request( + request_id, + verbose=self.log_requests, + prompt=prompt, + params=params, + arrival_time=arrival_time or time.time(), + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + data_parallel_rank=data_parallel_rank, + ) + + return stream.generator() + + async def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + data_parallel_rank: Optional[int] = None, + ) -> AsyncGenerator[RequestOutput, None]: + """Generate outputs for a request. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt to the LLM. See + [`PromptType`][vllm.inputs.PromptType] for more details about + the format of each input. + sampling_params: The sampling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: Prompt Adapter request to use + for generation, if any. + priority: The priority of the request. + Only applicable with priority scheduling. + data_parallel_rank: The (global) data parallel rank that must + handle this request. Only applicable if DP is enabled. + Yields: + The output `RequestOutput` objects from the LLMEngine + for the request. + + Details: + - If the engine is not running, start the background loop, + which iteratively invokes + [`engine_step`][vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step] + to process the waiting requests. + - Add the request to the engine's `RequestTracker`. + On the next background loop, this request will be sent to + the underlying engine. + Also, a corresponding `AsyncStream` will be created. + - Wait for the request outputs from `AsyncStream` and yield them. + + Example: + >>> # Please refer to entrypoints/api_server.py for + >>> # the complete example. + >>> + >>> # initialize the engine and the example input + >>> # note that engine_args here is AsyncEngineArgs instance + >>> engine = AsyncLLMEngine.from_engine_args(engine_args) + >>> example_input = { + >>> "prompt": "What is LLM?", + >>> "stream": False, # assume the non-streaming case + >>> "temperature": 0.0, + >>> "request_id": 0, + >>> } + >>> + >>> # start the generation + >>> results_generator = engine.generate( + >>> example_input["prompt"], + >>> SamplingParams(temperature=example_input["temperature"]), + >>> example_input["request_id"]) + >>> + >>> # get the results + >>> final_output = None + >>> async for request_output in results_generator: + >>> if await request.is_disconnected(): + >>> # Abort the request if the client disconnects. + >>> await engine.abort(request_id) + >>> # Return or raise an error + >>> ... + >>> final_output = request_output + >>> + >>> # Process and return the final output + >>> ... + """ + try: + async for output in await self.add_request( + request_id, + prompt, + sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + data_parallel_rank=data_parallel_rank, + ): + yield LLMEngine.validate_output(output, RequestOutput) + except asyncio.CancelledError: + await self.abort(request_id) + raise + + async def encode( + self, + prompt: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> AsyncGenerator[PoolingRequestOutput, None]: + """Generate outputs for a request from a pooling model. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt to the LLM. See + [`PromptType`][vllm.inputs.PromptType] for more details about + the format of each input. + pooling_params: The pooling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + priority: The priority of the request. + Only applicable with priority scheduling. + + Yields: + The output `PoolingRequestOutput` objects from the LLMEngine + for the request. + + Details: + - If the engine is not running, start the background loop, + which iteratively invokes + [`vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`][] + to process the waiting requests. + - Add the request to the engine's `RequestTracker`. + On the next background loop, this request will be sent to + the underlying engine. + Also, a corresponding `AsyncStream` will be created. + - Wait for the request outputs from `AsyncStream` and yield them. + + Example: + ``` + # Please refer to entrypoints/api_server.py for + # the complete example. + + # initialize the engine and the example input + # note that engine_args here is AsyncEngineArgs instance + engine = AsyncLLMEngine.from_engine_args(engine_args) + example_input = { + "input": "What is LLM?", + "request_id": 0, + } + + # start the generation + results_generator = engine.encode( + example_input["input"], + PoolingParams(), + example_input["request_id"]) + + # get the results + final_output = None + async for request_output in results_generator: + if await request.is_disconnected(): + # Abort the request if the client disconnects. + await engine.abort(request_id) + # Return or raise an error + ... + final_output = request_output + + # Process and return the final output + ... + ``` + """ + try: + async for output in await self.add_request( + request_id, + prompt, + pooling_params, + lora_request=lora_request, + trace_headers=trace_headers, + priority=priority, + ): + yield LLMEngine.validate_output(output, PoolingRequestOutput) + except asyncio.CancelledError: + await self.abort(request_id) + raise + + async def abort(self, request_id: str) -> None: + """Abort a request. + + Abort a submitted request. If the request is finished or not found, + this method will be a no-op. + + Args: + request_id: The unique id of the request. + """ + if not self.is_running: + raise AsyncEngineDeadError( + "Background loop is not running. If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") + + return self._abort(request_id) + + def _abort(self, request_id: str) -> None: + """Abort a request. + + Abort a submitted request. If the request is finished or not found, + this method will be a no-op. + + Args: + request_id: The unique id of the request. + """ + self._request_tracker.abort_request(request_id, + exception=asyncio.CancelledError, + verbose=self.log_requests) + + async def get_vllm_config(self) -> VllmConfig: + """Get the vllm configuration of the vLLM engine.""" + return self.engine.get_vllm_config() + + async def get_model_config(self) -> ModelConfig: + """Get the model configuration of the vLLM engine.""" + return self.engine.get_model_config() + + async def get_parallel_config(self) -> ParallelConfig: + """Get the parallel configuration of the vLLM engine.""" + return self.engine.get_parallel_config() + + async def get_decoding_config(self) -> DecodingConfig: + """Get the decoding configuration of the vLLM engine.""" + return self.engine.get_decoding_config() + + async def get_scheduler_config(self) -> SchedulerConfig: + """Get the scheduling configuration of the vLLM engine.""" + return self.engine.get_scheduler_config() + + async def get_lora_config(self) -> LoRAConfig: + """Get the lora configuration of the vLLM engine.""" + return self.engine.get_lora_config() + + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None) -> None: + self.engine.do_log_stats() + + async def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + t = time.perf_counter() + logger.debug("Starting health check...") + if self.is_stopped: + raise AsyncEngineDeadError("Background loop is stopped.") + + await self.engine.check_health_async() + logger.debug("Health check took %fs", time.perf_counter() - t) + + async def is_tracing_enabled(self) -> bool: + return self.engine.is_tracing_enabled() + + def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: + self.engine.add_logger(logger_name=logger_name, logger=logger) + + def remove_logger(self, logger_name: str) -> None: + self.engine.remove_logger(logger_name=logger_name) + + async def start_profile(self) -> None: + self.engine.start_profile() + + async def stop_profile(self) -> None: + self.engine.stop_profile() + + async def reset_mm_cache(self) -> None: + self.engine.reset_mm_cache() + + async def reset_prefix_cache(self, + device: Optional[Device] = None) -> None: + self.engine.reset_prefix_cache(device) + + async def sleep(self, level: int = 1) -> None: + self.engine.sleep(level) + + async def wake_up(self, tags: Optional[list[str]] = None) -> None: + self.engine.wake_up(tags) + + async def is_sleeping(self) -> bool: + return self.engine.is_sleeping() + + async def add_lora(self, lora_request: LoRARequest) -> None: + self.engine.add_lora(lora_request) + + async def collective_rpc(self, + method: str, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None): + """ + Perform a collective RPC call to the given path. + """ + return await self.engine.collective_rpc_async(method, timeout, args, + kwargs) + + +# TODO(v1): Remove this class proxy when V1 goes default. +if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: + from vllm.v1.engine.async_llm import AsyncLLM + + AsyncLLMEngine = AsyncLLM # type: ignore diff --git a/vllm/engine/async_timeout.py b/vllm/engine/async_timeout.py new file mode 100644 index 0000000..28a023a --- /dev/null +++ b/vllm/engine/async_timeout.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Workaround for https://github.com/python/cpython/issues/86296 +# +# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py +# Licensed under the Apache License (Apache-2.0) + +import asyncio +import enum +import sys +from types import TracebackType +from typing import Any, Optional, Type + +if sys.version_info[:2] >= (3, 11): + from asyncio import timeout as asyncio_timeout +else: + + def asyncio_timeout(delay: Optional[float]) -> "Timeout": + """timeout context manager. + Useful in cases when you want to apply timeout logic around block + of code or in cases when asyncio.wait_for is not suitable. For example: + >>> async with timeout(0.001): + ... async with aiohttp.get('https://github.com') as r: + ... await r.text() + delay - value in seconds or None to disable timeout logic + """ + loop = asyncio.get_running_loop() + deadline = loop.time() + delay if delay is not None else None + return Timeout(deadline, loop) + + class _State(enum.Enum): + INIT = "INIT" + ENTER = "ENTER" + TIMEOUT = "TIMEOUT" + EXIT = "EXIT" + + class Timeout: + # Internal class, please don't instantiate it directly + # Use timeout() and timeout_at() public factories instead. + # + # Implementation note: `async with timeout()` is preferred + # over `with timeout()`. + # While technically the Timeout class implementation + # doesn't need to be async at all, + # the `async with` statement explicitly points that + # the context manager should be used from async function context. + # + # This design allows to avoid many silly misusages. + # + # TimeoutError is raised immediately when scheduled + # if the deadline is passed. + # The purpose is to time out as soon as possible + # without waiting for the next await expression. + + __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler") + + def __init__(self, deadline: Optional[float], + loop: asyncio.AbstractEventLoop) -> None: + self._loop = loop + self._state = _State.INIT + + self._timeout_handler = None # type: Optional[asyncio.Handle] + if deadline is None: + self._deadline = None # type: Optional[float] + else: + self.update(deadline) + + async def __aenter__(self) -> "Timeout": + self._do_enter() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + @property + def expired(self) -> bool: + """Is timeout expired during execution?""" + return self._state == _State.TIMEOUT + + @property + def deadline(self) -> Optional[float]: + return self._deadline + + def reject(self) -> None: + """Reject scheduled timeout if any.""" + # cancel is maybe better name but + # task.cancel() raises CancelledError in asyncio world. + if self._state not in (_State.INIT, _State.ENTER): + raise RuntimeError(f"invalid state {self._state.value}") + self._reject() + + def _reject(self) -> None: + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._timeout_handler = None + + def shift(self, delay: float) -> None: + """Advance timeout on delay seconds. + The delay can be negative. + Raise RuntimeError if shift is called when deadline is not scheduled + """ + deadline = self._deadline + if deadline is None: + raise RuntimeError( + "cannot shift timeout if deadline is not scheduled") + self.update(deadline + delay) + + def update(self, deadline: float) -> None: + """Set deadline to absolute value. + deadline argument points on the time in the same clock system + as loop.time(). + If new deadline is in the past the timeout is raised immediately. + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. + """ + if self._state == _State.EXIT: + raise RuntimeError( + "cannot reschedule after exit from context manager") + if self._state == _State.TIMEOUT: + raise RuntimeError("cannot reschedule expired timeout") + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._deadline = deadline + if self._state != _State.INIT: + self._reschedule() + + def _reschedule(self) -> None: + assert self._state == _State.ENTER + deadline = self._deadline + if deadline is None: + return + + now = self._loop.time() + if self._timeout_handler is not None: + self._timeout_handler.cancel() + + task = asyncio.current_task() + if deadline <= now: + self._timeout_handler = self._loop.call_soon( + self._on_timeout, task) + else: + self._timeout_handler = self._loop.call_at( + deadline, self._on_timeout, task) + + def _do_enter(self) -> None: + if self._state != _State.INIT: + raise RuntimeError(f"invalid state {self._state.value}") + self._state = _State.ENTER + self._reschedule() + + def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None: + if exc_type is asyncio.CancelledError and \ + self._state == _State.TIMEOUT: + self._timeout_handler = None + raise asyncio.TimeoutError + # timeout has not expired + self._state = _State.EXIT + self._reject() + return None + + def _on_timeout(self, task: "Optional[asyncio.Task[Any]]") -> None: + if task: + task.cancel() + self._state = _State.TIMEOUT + # drop the reference early + self._timeout_handler = None diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py new file mode 100644 index 0000000..d3d95e6 --- /dev/null +++ b/vllm/engine/llm_engine.py @@ -0,0 +1,2143 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import copy +import time +from collections import Counter as collectionsCounter +from collections import deque +from contextlib import contextmanager +from dataclasses import dataclass +from functools import partial +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, + Iterable, List, Literal, Mapping, NamedTuple, Optional) +from typing import Sequence as GenericSequence +from typing import Set, Type, Union, cast + +import torch +from typing_extensions import TypeVar + +import vllm.envs as envs +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ObservabilityConfig, ParallelConfig, SchedulerConfig, + VllmConfig) +from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.metrics_types import StatLoggerBase, Stats +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.engine.output_processor.util import create_output_by_sequence_group +from vllm.entrypoints.openai.logits_processors import ( + get_logits_processors as get_openai_logits_processors) +from vllm.executor.executor_base import ExecutorBase +from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs +from vllm.inputs.parse import split_enc_dec_inputs +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.logits_process import get_bad_words_logits_processors +from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + get_local_guided_decoding_logits_processor) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.processing import EncDecMultiModalProcessor +from vllm.outputs import (PoolingRequestOutput, RequestOutput, + RequestOutputFactory) +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup, + PoolingSequenceGroupOutput, Sequence, SequenceGroup, + SequenceGroupBase, SequenceGroupMetadata, + SequenceGroupOutput, SequenceStatus, CompletionSequenceGroupOutput, VLLM_INVALID_TOKEN_ID) +from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, + init_tracer) +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer +# DEBUG add cpm tokenizer +from vllm.transformers_utils.tokenizers import CPM9GTokenizer +from vllm.transformers_utils.tokenizer_group import ( + TokenizerGroup, init_tokenizer_from_configs) +from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, + usage_message) +from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind +from vllm.version import __version__ as VLLM_VERSION +from vllm.worker.model_runner_base import InputProcessingError +from vllm.profiler.prof import profile + +logger = init_logger(__name__) +_LOCAL_LOGGING_INTERVAL_SEC = 5 + +_O = TypeVar("_O", RequestOutput, PoolingRequestOutput) +_R = TypeVar("_R", default=Any) + + +@dataclass +class SchedulerOutputState: + """Caches the scheduler outputs for a virtual engine. Used for Multi-Step""" + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None + scheduler_outputs: Optional[SchedulerOutputs] = None + allow_async_output_proc: bool = False + last_output: Optional[SamplerOutput] = None + + +class OutputData(NamedTuple): + outputs: List[SamplerOutput] + seq_group_metadata_list: List[SequenceGroupMetadata] + scheduler_outputs: SchedulerOutputs + is_async: bool + is_last_step: bool + # Indicates if this output is from the first step of the + # multi-step. When multi-step is disabled, this is always + # set to True. + # is_first_step_output is invalid when `outputs` has + # outputs from multiple steps. + is_first_step_output: Optional[bool] + skip: List[int] + + +class SchedulerContext: + + def __init__(self, multi_step_stream_outputs: bool = False): + self.output_queue: Deque[OutputData] = deque() + self.request_outputs: List[Union[RequestOutput, + PoolingRequestOutput]] = [] + self.seq_group_metadata_list: Optional[ + List[SequenceGroupMetadata]] = None + self.scheduler_outputs: Optional[SchedulerOutputs] = None + + self.multi_step_stream_outputs: bool = multi_step_stream_outputs + + def append_output(self, outputs: List[SamplerOutput], + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduler_outputs: SchedulerOutputs, is_async: bool, + is_last_step: bool, + is_first_step_output: Optional[bool]): + self.output_queue.append( + OutputData(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=is_async, + is_last_step=is_last_step, + is_first_step_output=is_first_step_output, + skip=[])) + + +class LLMEngine: + """An LLM engine that receives requests and generates texts. + + This is the main class for the vLLM engine. It receives requests + from clients and generates texts from the LLM. It includes a tokenizer, a + language model (possibly distributed across multiple GPUs), and GPU memory + space allocated for intermediate states (aka KV cache). This class utilizes + iteration-level scheduling and efficient memory management to maximize the + serving throughput. + + The [`LLM`][vllm.LLM] class wraps this class for offline batched inference + and the [`AsyncLLMEngine`][vllm.engine.async_llm_engine.AsyncLLMEngine] + class wraps this class for online serving. + + The config arguments are derived from [`EngineArgs`][vllm.EngineArgs]. + + Args: + vllm_config: The configuration for initializing and running vLLM. + executor_class: The model executor class for managing distributed + execution. + log_stats: Whether to log statistics. + usage_context: Specified entry point, used for usage info collection. + """ + + DO_VALIDATE_OUTPUT: ClassVar[bool] = False + """A flag to toggle whether to validate the type of request output.""" + + @classmethod + @contextmanager + def enable_output_validation(cls): + cls.DO_VALIDATE_OUTPUT = True + + yield + + cls.DO_VALIDATE_OUTPUT = False + + @classmethod + def validate_output( + cls, + output: object, + output_type: Type[_O], + ) -> _O: + do_validate = cls.DO_VALIDATE_OUTPUT + + if ((TYPE_CHECKING or do_validate) + and not isinstance(output, output_type)): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + return cast(_O, output) + + @classmethod + def validate_outputs( + cls, + outputs: GenericSequence[object], + output_type: Type[_O], + ) -> List[_O]: + do_validate = cls.DO_VALIDATE_OUTPUT + + outputs_: List[_O] + if TYPE_CHECKING or do_validate: + outputs_ = [] + for output in outputs: + if not isinstance(output, output_type): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + outputs_.append(output) + else: + outputs_ = outputs + + return outputs_ + + tokenizer: Optional[TokenizerGroup] + + def __init__( + self, + vllm_config: VllmConfig, + executor_class: Type[ExecutorBase], + log_stats: bool, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + use_cached_outputs: bool = False, + ) -> None: + if envs.VLLM_USE_V1: + raise ValueError( + "Using V0 LLMEngine, but envs.VLLM_USE_V1=True. " + "This should not happen. As a workaround, try using " + "LLMEngine.from_vllm_config(...) or explicitly set " + "VLLM_USE_V1=0 or 1 and report this issue on Github.") + + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config # noqa + self.load_config = vllm_config.load_config + self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa + ) + self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa + self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa + ) + + logger.info( + "Initializing a V0 LLM engine (v%s) with config: %s, " + "use_cached_outputs=%s, ", + VLLM_VERSION, + vllm_config, + use_cached_outputs, + ) + + self.log_stats = log_stats + self.use_cached_outputs = use_cached_outputs + + if not self.model_config.skip_tokenizer_init and self.model_config.tokenizer_mode != "cpm": + self.tokenizer = self._init_tokenizer() + self.detokenizer = Detokenizer(self.tokenizer) + tokenizer_group = self.get_tokenizer_group() + elif self.model_config.tokenizer_mode == "cpm": + self.tokenizer = CPM9GTokenizer(self.model_config.model, trust_remote_code=True) + self.detokenizer = Detokenizer(self.tokenizer, self.model_config.tokenizer_mode) + tokenizer_group = self.get_tokenizer_group() + else: + self.tokenizer = None + self.detokenizer = None + tokenizer_group = None + + # Ensure that the function doesn't contain a reference to self, + # to avoid engine GC issues + def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: + assert tokenizer_group, ("tokenizer_group cannot be None, " + "make sure skip_tokenizer_init is False") + return tokenizer_group.get_lora_tokenizer(sequence.lora_request) + + self.seq_counter = Counter() + self.generation_config_fields = ( + self.model_config.try_get_generation_config()) + + self.input_preprocessor = InputPreprocessor(self.model_config, + self.tokenizer, + mm_registry) + + self.model_executor = executor_class(vllm_config=vllm_config) + + if self.model_config.runner_type != "pooling": + self._initialize_kv_caches() + + # If usage stat is enabled, collect relevant info. + if is_usage_stats_enabled(): + from vllm.model_executor.model_loader import ( + get_architecture_class_name) + usage_message.report_usage( + get_architecture_class_name(self.model_config), + usage_context, + extra_kvs={ + # Common configuration + "dtype": + str(self.model_config.dtype), + "tensor_parallel_size": + self.parallel_config.tensor_parallel_size, + "block_size": + self.cache_config.block_size, + "gpu_memory_utilization": + self.cache_config.gpu_memory_utilization, + + # Quantization + "quantization": + self.model_config.quantization, + "kv_cache_dtype": + str(self.cache_config.cache_dtype), + + # Feature flags + "enable_lora": + bool(self.lora_config), + "enable_prompt_adapter": + bool(self.prompt_adapter_config), + "enable_prefix_caching": + self.cache_config.enable_prefix_caching, + "enforce_eager": + self.model_config.enforce_eager, + "disable_custom_all_reduce": + self.parallel_config.disable_custom_all_reduce, + }) + + self.cached_scheduler_outputs = [ + SchedulerOutputState() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + self.scheduler_contexts = [ + SchedulerContext(multi_step_stream_outputs=self.scheduler_config. + multi_step_stream_outputs) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + if self.model_config.use_async_output_proc: + process_model_outputs = weak_bind(self._process_model_outputs) + + self.async_callbacks = [ + partial(process_model_outputs, + ctx=self.scheduler_contexts[v_id]) + for v_id in range(self.parallel_config.pipeline_parallel_size) + ] + else: + self.async_callbacks = [] + + # Currently used by AsyncLLMEngine to ensure quick append + # of request outputs to asyncio queues + self.process_request_outputs_callback: Optional[Callable] = None + + # Create the scheduler. + # NOTE: the cache_config here have been updated with the numbers of + # GPU and CPU blocks, which are profiled in the distributed executor. + if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str): + Scheduler = resolve_obj_by_qualname( + self.vllm_config.scheduler_config.scheduler_cls) + else: + Scheduler = self.vllm_config.scheduler_config.scheduler_cls + self.scheduler = [ + Scheduler( + self.scheduler_config, self.cache_config, self.lora_config, + self.parallel_config.pipeline_parallel_size, + self.async_callbacks[v_id] + if self.model_config.use_async_output_proc else None) + for v_id in range(self.parallel_config.pipeline_parallel_size) + ] + + # Metric Logging. + if self.log_stats: + if stat_loggers is not None: + self.stat_loggers = stat_loggers + else: + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from vllm.engine.metrics import (LoggingStatLogger, + PrometheusStatLogger) + + self.stat_loggers = { + "logging": + LoggingStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC, + vllm_config=vllm_config), + "prometheus": + PrometheusStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC, + labels=dict( + model_name=self.model_config.served_model_name), + vllm_config=vllm_config), + } + self.stat_loggers["prometheus"].info("cache_config", + self.cache_config) + + self.tracer = None + if self.observability_config.otlp_traces_endpoint: + self.tracer = init_tracer( + "vllm.llm_engine", + self.observability_config.otlp_traces_endpoint) + + # Create sequence output processor, e.g. for beam search or + # speculative decoding. + self.output_processor = ( + SequenceGroupOutputProcessor.create_output_processor( + self.scheduler_config, + self.detokenizer, + self.scheduler, + self.seq_counter, + get_tokenizer_for_seq, + stop_checker=StopChecker(self.scheduler_config.max_model_len, + get_tokenizer_for_seq), + )) + + self.tree_decoding = os.environ.get('VLLM_TREE_DECODING') == '1' + + self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} + + # Flag to set when an input fails to process and the engine should run + # the next step without re-scheduling. + self._skip_scheduling_next_step = False + profile.StartTracer() + + # Don't keep the dummy data in memory + self.reset_mm_cache() + + def _initialize_kv_caches(self) -> None: + """Initialize the KV cache in the worker(s). + + The workers will determine the number of blocks in both the GPU cache + and the swap CPU cache. + """ + start = time.time() + num_gpu_blocks, num_cpu_blocks = ( + self.model_executor.determine_num_available_blocks()) + + if self.cache_config.num_gpu_blocks_override is not None: + num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override + logger.info( + "Overriding num_gpu_blocks=%d with " + "num_gpu_blocks_override=%d", num_gpu_blocks, + num_gpu_blocks_override) + num_gpu_blocks = num_gpu_blocks_override + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) + elapsed = time.time() - start + logger.info(("init engine (profile, create kv cache, " + "warmup model) took %.2f seconds"), elapsed) + + @classmethod + def _get_executor_cls(cls, + engine_config: VllmConfig) -> Type[ExecutorBase]: + # distributed_executor_backend must be set in VllmConfig.__post_init__ + distributed_executor_backend = ( + engine_config.parallel_config.distributed_executor_backend) + # Initialize the cluster and specify the executor class. + if isinstance(distributed_executor_backend, type): + if not issubclass(distributed_executor_backend, ExecutorBase): + raise TypeError( + "distributed_executor_backend must be a subclass of " + f"ExecutorBase. Got {distributed_executor_backend}.") + executor_class = distributed_executor_backend + elif distributed_executor_backend == "ray": + from vllm.executor.ray_distributed_executor import ( + RayDistributedExecutor) + executor_class = RayDistributedExecutor + elif distributed_executor_backend == "mp": + from vllm.executor.mp_distributed_executor import ( + MultiprocessingDistributedExecutor) + assert not envs.VLLM_USE_RAY_SPMD_WORKER, ( + "multiprocessing distributed executor backend does not " + "support VLLM_USE_RAY_SPMD_WORKER=1") + executor_class = MultiprocessingDistributedExecutor + elif distributed_executor_backend == "uni": + # JAX-style, single-process, multi-device executor. + from vllm.executor.uniproc_executor import UniProcExecutor + executor_class = UniProcExecutor + elif distributed_executor_backend == "external_launcher": + # executor with external launcher + from vllm.executor.uniproc_executor import ( # noqa + ExecutorWithExternalLauncher) + executor_class = ExecutorWithExternalLauncher + else: + raise ValueError("unrecognized distributed_executor_backend: " + f"{distributed_executor_backend}") + return executor_class + + @classmethod + def from_vllm_config( + cls, + vllm_config: VllmConfig, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + disable_log_stats: bool = False, + ) -> "LLMEngine": + return cls( + vllm_config=vllm_config, + executor_class=cls._get_executor_cls(vllm_config), + log_stats=(not disable_log_stats), + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + + @classmethod + def from_engine_args( + cls, + engine_args: EngineArgs, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + ) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + # Create the engine configs. + vllm_config = engine_args.create_engine_config(usage_context) + + engine_cls = cls + if envs.VLLM_USE_V1: + from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine + engine_cls = V1LLMEngine + + return engine_cls.from_vllm_config( + vllm_config=vllm_config, + usage_context=usage_context, + stat_loggers=stat_loggers, + disable_log_stats=engine_args.disable_log_stats, + ) + + def __reduce__(self): + # This is to ensure that the LLMEngine is not referenced in + # the closure used to initialize Ray worker actors + raise RuntimeError("LLMEngine should not be pickled!") + + def __del__(self): + # Shutdown model executor when engine is garbage collected + # Use getattr since __init__ can fail before the field is set + if model_executor := getattr(self, "model_executor", None): + model_executor.shutdown() + + def get_tokenizer_group(self) -> TokenizerGroup: + if self.tokenizer is None: + raise ValueError("Unable to get tokenizer because " + "skip_tokenizer_init is True") + + return self.tokenizer + + def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + if self.model_config.tokenizer_mode == "cpm": + return self.tokenizer + else: + return self.get_tokenizer_group().get_lora_tokenizer(lora_request) + + def _init_tokenizer(self) -> TokenizerGroup: + return init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=self.scheduler_config, + lora_config=self.lora_config) + + def _verify_args(self) -> None: + self.model_config.verify_with_parallel_config(self.parallel_config) + self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.lora_config: + self.lora_config.verify_with_model_config(self.model_config) + self.lora_config.verify_with_scheduler_config( + self.scheduler_config) + if self.prompt_adapter_config: + self.prompt_adapter_config.verify_with_model_config( + self.model_config) + + def _add_processed_request( + self, + request_id: str, + processed_inputs: ProcessorInputs, + params: Union[SamplingParams, PoolingParams], + arrival_time: float, + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> Optional[SequenceGroup]: + """Add a processed request to the engine's request pool. + return the created sequence group. + """ + if isinstance(params, SamplingParams) and params.n > 1: + ParallelSampleSequenceGroup.add_request( + request_id, + self, + params, + processed_inputs=processed_inputs, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ) + return None + + self._validate_model_inputs(processed_inputs, lora_request) + # Create the sequences. + block_size = self.cache_config.block_size + seq_id = next(self.seq_counter) + #DEBUG @TODO change tokenizer false + if self.model_config.tokenizer_mode == "cpm": + eos_token_id = self.tokenizer.eos_id + else: + eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) + + encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) + + seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id, + lora_request, prompt_adapter_request) + + encoder_seq = (None if encoder_inputs is None else Sequence( + seq_id, encoder_inputs, block_size, eos_token_id, lora_request, + prompt_adapter_request)) + + # Create a SequenceGroup based on SamplingParams or PoolingParams + if isinstance(params, SamplingParams): + seq_group = self._create_sequence_group_with_sampling( + request_id, + seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq, + priority=priority) + elif isinstance(params, PoolingParams): + seq_group = self._create_sequence_group_with_pooling( + request_id, + seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq, + priority=priority) + else: + raise ValueError( + "Either SamplingParams or PoolingParams must be provided.") + + # Add the sequence group to the scheduler with least unfinished seqs. + costs = [ + scheduler.get_num_unfinished_seq_groups() + for scheduler in self.scheduler + ] + min_cost_scheduler = self.scheduler[costs.index(min(costs))] + min_cost_scheduler.add_seq_group(seq_group) + + return seq_group + + def stop_remote_worker_execution_loop(self) -> None: + self.model_executor.stop_remote_worker_execution_loop() + + def add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + tokenization_kwargs: Optional[dict[str, Any]] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + """Add a request to the engine's request pool. + + The request is added to the request pool and will be processed by the + scheduler as `engine.step()` is called. The exact scheduling policy is + determined by the scheduler. + + Args: + request_id: The unique ID of the request. + prompt: The prompt to the LLM. See + [PromptType][vllm.inputs.PromptType] + for more details about the format of each input. + params: Parameters for sampling or pooling. + [SamplingParams][vllm.SamplingParams] for text generation. + [PoolingParams][vllm.PoolingParams] for pooling. + arrival_time: The arrival time of the request. If None, we use + the current monotonic time. + lora_request: The LoRA request to add. + trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: The prompt adapter request to add. + priority: The priority of the request. + Only applicable with priority scheduling. + + Details: + - Set arrival_time to the current time if it is None. + - Set prompt_token_ids to the encoded prompt if it is None. + - Create `n` number of [Sequence][vllm.Sequence] objects. + - Create a [SequenceGroup][vllm.SequenceGroup] object + from the list of [Sequence][vllm.Sequence]. + - Add the [SequenceGroup][vllm.SequenceGroup] object to the + scheduler. + + Example: + >>> # initialize engine + >>> engine = LLMEngine.from_engine_args(engine_args) + >>> # set request arguments + >>> example_prompt = "Who is the president of the United States?" + >>> sampling_params = SamplingParams(temperature=0.0) + >>> request_id = 0 + >>> + >>> # add the request to the engine + >>> engine.add_request( + >>> str(request_id), + >>> example_prompt, + >>> SamplingParams(temperature=0.0)) + >>> # continue the request processing + >>> ... + """ + if not isinstance(request_id, str): + raise TypeError( + f"request_id must be a string, got {type(request_id)}") + + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + + if priority != 0 and not self.scheduler_config.policy == "priority": + raise ValueError(f"Got priority {priority} but " + "Priority scheduling is not enabled.") + + if isinstance(params, SamplingParams) \ + and (params.guided_decoding or params.logits_processors) \ + and self.scheduler_config.num_scheduler_steps > 1: + raise ValueError( + "Guided decoding and logits processors are not supported " + "in multi-step decoding") + + if arrival_time is None: + arrival_time = time.time() + + if (isinstance(prompt, dict) + and prompt.get("prompt_embeds", None) is not None + and not prompt.get("prompt_token_ids", None)): + seq_len = prompt["prompt_embeds"].shape[0] + prompt["prompt_token_ids"] = [0] * seq_len + + #DEBUG anrongqiao + if self.model_config.tokenizer_mode == "cpm": + lora_request = None + + processed_inputs = self.input_preprocessor.preprocess( + prompt, + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + self._add_processed_request( + request_id=request_id, + processed_inputs=processed_inputs, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers, + priority=priority, + ) + + def _create_sequence_group_with_sampling( + self, + request_id: str, + seq: Sequence, + sampling_params: SamplingParams, + arrival_time: float, + lora_request: Optional[LoRARequest], + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + encoder_seq: Optional[Sequence] = None, + priority: int = 0, + ) -> SequenceGroup: + """Creates a SequenceGroup with SamplingParams.""" + max_logprobs = self.get_model_config().max_logprobs + if (sampling_params.logprobs + and sampling_params.logprobs > max_logprobs) or ( + sampling_params.prompt_logprobs + and sampling_params.prompt_logprobs > max_logprobs): + raise ValueError(f"Cannot request more than " + f"{max_logprobs} logprobs.") + + sampling_params = self._build_logits_processors( + sampling_params, lora_request) + + # Defensive copy of SamplingParams, which are used by the sampler, + # this doesn't deep-copy LogitsProcessor objects + sampling_params = sampling_params.clone() + + sampling_params.update_from_generation_config( + self.generation_config_fields, seq.eos_token_id) + + # Create the sequence group. + draft_size = 1 + if self.vllm_config.speculative_config is not None: + draft_size = \ + self.vllm_config.speculative_config.num_speculative_tokens + 1 + seq_group = SequenceGroup( + request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + sampling_params=sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq, + priority=priority, + draft_size=draft_size) + + return seq_group + + def _create_sequence_group_with_pooling( + self, + request_id: str, + seq: Sequence, + pooling_params: PoolingParams, + arrival_time: float, + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + encoder_seq: Optional[Sequence] = None, + priority: int = 0, + ) -> SequenceGroup: + """Creates a SequenceGroup with PoolingParams.""" + # Defensive copy of PoolingParams, which are used by the pooler + pooling_params = pooling_params.clone() + # Create the sequence group. + seq_group = SequenceGroup( + request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + lora_request=lora_request, + pooling_params=pooling_params, + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq, + priority=priority) + return seq_group + + def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: + """Aborts a request(s) with the given ID. + + Args: + request_id: The ID(s) of the request to abort. + + Details: + - Refer to [vllm.core.scheduler.Scheduler.abort_seq_group][]. + + Example: + >>> # initialize engine and add a request with request_id + >>> request_id = str(0) + >>> # abort the request + >>> engine.abort_request(request_id) + """ + for scheduler in self.scheduler: + scheduler.abort_seq_group( + request_id, seq_id_to_seq_group=self.seq_id_to_seq_group) + + def get_vllm_config(self) -> VllmConfig: + """Gets the vllm configuration.""" + return self.vllm_config + + def get_model_config(self) -> ModelConfig: + """Gets the model configuration.""" + return self.model_config + + def get_parallel_config(self) -> ParallelConfig: + """Gets the parallel configuration.""" + return self.parallel_config + + def get_decoding_config(self) -> DecodingConfig: + """Gets the decoding configuration.""" + return self.decoding_config + + def get_scheduler_config(self) -> SchedulerConfig: + """Gets the scheduler configuration.""" + return self.scheduler_config + + def get_lora_config(self) -> LoRAConfig: + """Gets the LoRA configuration.""" + return self.lora_config + + def get_num_unfinished_requests(self) -> int: + """Gets the number of unfinished requests.""" + return sum(scheduler.get_num_unfinished_seq_groups() + for scheduler in self.scheduler) + + def has_unfinished_requests(self) -> bool: + """Returns True if there are unfinished requests.""" + return any(scheduler.has_unfinished_seqs() + for scheduler in self.scheduler) + + def has_unfinished_requests_for_virtual_engine( + self, virtual_engine: int) -> bool: + """ + Returns True if there are unfinished requests for the virtual engine. + """ + return self.scheduler[virtual_engine].has_unfinished_seqs() + + def reset_mm_cache(self) -> bool: + """Reset the multi-modal cache.""" + return self.input_preprocessor.mm_registry.reset_processor_cache() + + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + """Reset prefix cache for all devices.""" + + success = True + for scheduler in self.scheduler: + success = success and scheduler.reset_prefix_cache(device) + return success + + @staticmethod + def _process_sequence_group_outputs( + seq_group: SequenceGroup, + outputs: List[PoolingSequenceGroupOutput], + ) -> None: + seq_group.pooled_data = outputs[0].data + + for seq in seq_group.get_seqs(): + seq.status = SequenceStatus.FINISHED_STOPPED + + return + + def _update_num_computed_tokens_for_multi_step_prefill( + self, seq_group: SequenceGroup, + seq_group_meta: SequenceGroupMetadata, + is_first_step_output: Optional[bool]): + """ + This function updates num_computed_tokens for prompt sequences + when Multi-Step is enabled. + + seq_group: SequenceGroup to update the num_computed_tokens for. + seq_group_meta: Metadata of the given SequenceGroup. + is_first_step_output: Optional[bool] - + When available, is_first_step_output indicates if the appended + output token is the output of the first-step in multi-step. + A value of None indicates that outputs from all steps in + in multi-step are submitted in a single burst. + """ + + assert self.scheduler_config.is_multi_step + + if not seq_group_meta.is_prompt: + # num_computed_token updates for multi-step decodes happen after + # the tokens are appended to the sequence. + return + + do_update: bool = False + if self.scheduler_config.chunked_prefill_enabled: + # In multi-step + chunked-prefill case, the prompt sequences + # that are scheduled are fully processed in the first step. + do_update = is_first_step_output is None or is_first_step_output + else: + # Normal multi-step decoding case. In this case prompt-sequences + # are actually single-stepped. Always update in this case. + assert seq_group.state.num_steps == 1 + do_update = True + + if do_update: + seq_group.update_num_computed_tokens( + seq_group_meta.token_chunk_size) + + def _process_model_outputs(self, + ctx: SchedulerContext, + request_id: Optional[str] = None) -> None: + """Apply the model output to the sequences in the scheduled seq groups + and return responses. + + ctx: The virtual engine context to work on + request_id: If provided, then only this request is going to be processed + """ + + now = time.time() + + if len(ctx.output_queue) == 0: + return None + + # Get pending async postprocessor + if request_id: + # When we process only one request, no pop is required + # (since later we will process all of the rest) + (outputs, seq_group_metadata_list, scheduler_outputs, is_async, + is_last_step, is_first_step_output, skip) = ctx.output_queue[0] + else: + (outputs, seq_group_metadata_list, scheduler_outputs, is_async, + is_last_step, is_first_step_output, + skip) = ctx.output_queue.popleft() + + # Sanity check + assert len(seq_group_metadata_list) == len( + scheduler_outputs.scheduled_seq_groups) + + has_multiple_outputs: bool = len(outputs) > 1 + outputs_by_sequence_group: List[List[SequenceGroupOutput]] + if has_multiple_outputs: + assert self.scheduler_config.is_multi_step or \ + self.speculative_config + # Organize outputs by [step][sequence group] instead of + # [sequence group][step]. + if self.scheduler_config.is_multi_step: + outputs_by_sequence_group = create_output_by_sequence_group( + outputs, len(seq_group_metadata_list)) + elif self.speculative_config: + # Decodes are multi-steps while prefills are not, outputting at + # most 1 token. Separate them so that we can trigger chunk + # processing without having to pad or copy over prompts K times + # to match decodes structure (costly with prompt_logprobs). + num_prefills = sum(sg.is_prompt + for sg in seq_group_metadata_list) + prefills, decodes = outputs[:num_prefills], outputs[ + num_prefills:] + outputs_by_sequence_group = create_output_by_sequence_group( + decodes, + num_seq_groups=len(seq_group_metadata_list) - num_prefills) + outputs_by_sequence_group = [p.outputs for p in prefills + ] + outputs_by_sequence_group + # We have outputs for multiple steps submitted in a single burst, + # so invalidate is_first_step_output. + is_first_step_output = None + elif len(outputs) == 1: + outputs_by_sequence_group = outputs + else: + return None + + # Determine the requests we need to operate on + if request_id: + indices = [] + for i, seq_group_meta in enumerate(seq_group_metadata_list): + if seq_group_meta.request_id == request_id: + assert i not in skip # Cannot be called twice + indices.append(i) + break + + # If the request_id was not found, then it means that + # this is a new request that has no pending async + # postprocessor + if not indices: + return + else: + indices = range(len(seq_group_metadata_list)) # type: ignore + + finished_before: List[int] = [] + finished_now: List[int] = [] + empty_seq_indices: List[int] = [] + for i in indices: + if i in skip: + continue + + seq_group_meta = seq_group_metadata_list[i] + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + + seq_group: SequenceGroup = scheduled_seq_group.seq_group + + if seq_group.is_finished(): + finished_before.append(i) + continue + + output: List[SequenceGroupOutput] + if has_multiple_outputs: + output = outputs_by_sequence_group[i] + else: + output = [outputs_by_sequence_group[0][i]] + + # tree style speculative decoding may generate empty output in first step + if self.tree_decoding and outputs and isinstance(output[0], CompletionSequenceGroupOutput): + samples = [o.samples[0] for o in output] + valid_samples = [ + sample for sample in samples + if sample.output_token != VLLM_INVALID_TOKEN_ID + ] + if len(valid_samples) == 0: + empty_seq_indices.append(i) + continue + + if not is_async: + if self.scheduler_config.is_multi_step: + # Updates happen only if the sequence is prefill + self._update_num_computed_tokens_for_multi_step_prefill( + seq_group, seq_group_meta, is_first_step_output) + else: + seq_group.update_num_computed_tokens( + seq_group_meta.token_chunk_size or 0) + + if outputs: + for o in outputs: + if (isinstance(o, SamplerOutput) + and seq_group.metrics is not None): + if seq_group.metrics.model_forward_time is not None: + seq_group.metrics.model_forward_time += ( + o.model_forward_time or 0) + else: + seq_group.metrics.model_forward_time = ( + o.model_forward_time) + if seq_group.metrics.model_execute_time is not None: + seq_group.metrics.model_execute_time += ( + o.model_execute_time or 0) + else: + seq_group.metrics.model_execute_time = ( + o.model_execute_time) + + if self.model_config.runner_type == "pooling": + self._process_sequence_group_outputs(seq_group, output) + else: + self.output_processor.process_prompt_logprob(seq_group, output) + if seq_group_meta.do_sample: + self.output_processor.process_outputs( + seq_group, output, is_async) + + if seq_group.is_finished(): + finished_now.append(i) + + # Generate outputs for the requests that finished this iteration + for i in finished_now: + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + + seq_group = scheduled_seq_group.seq_group + seq_group.maybe_set_first_token_time(now) + if not seq_group.is_prefill(): + seq_group.set_last_token_time(now) + request_output = RequestOutputFactory.create( + seq_group, + self.seq_id_to_seq_group, + use_cache=self.use_cached_outputs) + if request_output: + ctx.request_outputs.append(request_output) + + # When we process a single request, we skip it for the next time, + # and invoke the request output callback (if there was final output) + if request_id: + assert len(indices) == 1 + skip.append(indices[0]) + + if (finished_now + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + + # Free currently finished requests + if finished_now: + for scheduler in self.scheduler: + scheduler.free_finished_seq_groups() + + # For multi-step without streaming, don't create outputs each iteration + if not is_last_step and not ctx.multi_step_stream_outputs: + # Immediately process request outputs here (if callback is given) + if (finished_now + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + + # Create the outputs + for i in indices: + if i in skip or i in finished_before or i in finished_now or i in empty_seq_indices: + continue # Avoids double processing + + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + + seq_group = scheduled_seq_group.seq_group + seq_group.maybe_set_first_token_time(now) + if not seq_group.is_prefill(): + seq_group.set_last_token_time(now) + request_output = RequestOutputFactory.create( + seq_group, + self.seq_id_to_seq_group, + use_cache=self.use_cached_outputs) + if request_output: + ctx.request_outputs.append(request_output) + + # For multi-step with streaming, create outputs each iteration + if not is_last_step and ctx.multi_step_stream_outputs: + # Immediately process request outputs here (if callback is given) + if self.process_request_outputs_callback is not None: + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + + for seq_group in scheduler_outputs.ignored_seq_groups: + params = seq_group.sampling_params + if params is not None and params.output_kind == ( + RequestOutputKind.DELTA) and not seq_group.is_finished(): + continue + + request_output = RequestOutputFactory.create( + seq_group, + self.seq_id_to_seq_group, + use_cache=self.use_cached_outputs, + ) + if request_output: + ctx.request_outputs.append(request_output) + + # Immediately process request outputs here (if callback is given) + if (ctx.request_outputs + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + + # For async case, we need to record the stats here. + # For non-async case, the stats are done in the + # LLMEngine/AsyncLLMEngine directly + if is_async: + # Log stats. + self.do_log_stats(scheduler_outputs, outputs, finished_before, + skip) + + # Tracing + self.do_tracing(scheduler_outputs, finished_before) + + return None + + def _advance_to_next_step( + self, output: SamplerOutput, + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None: + """Given model output from a single run, append the tokens to the + sequences. This is normally done inside output processor, but it is + required if the worker is to perform async forward pass to next step. + """ + for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \ + zip(seq_group_metadata_list, output, scheduled_seq_groups): + seq_group = scheduled_seq_group.seq_group + + if seq_group.is_finished(): + continue + + if self.scheduler_config.is_multi_step: + # Updates happen only if the sequence is prefill + self._update_num_computed_tokens_for_multi_step_prefill( + seq_group, seq_group_metadata, + seq_group.state.num_steps == 1) + else: + token_chunk_size = (seq_group_metadata.token_chunk_size + if seq_group_metadata.token_chunk_size + is not None else 0) + seq_group.update_num_computed_tokens(token_chunk_size) + + if seq_group_metadata.do_sample: + assert len(sequence_group_outputs.samples) == 1, ( + "Async output processor expects a single sample" + " (i.e sampling_params.n == 1)") + sample = sequence_group_outputs.samples[0] + + assert len(seq_group.seqs) == 1 + seq = seq_group.seqs[0] + + if self.scheduler_config.is_multi_step: + is_prefill_append = seq.data.get_num_uncomputed_tokens( + ) == 0 + seq.append_token_id(sample.output_token, sample.logprobs, + sample.output_embed) + if not is_prefill_append: + seq_group.update_num_computed_tokens(1) + else: + seq.append_token_id(sample.output_token, sample.logprobs, + sample.output_embed) + + def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: + """Performs one decoding iteration and returns newly generated results. + +
+ ![Overview of the step function](https://i.imgur.com/sv2HssD.png) +
Overview of the step function
+
+ + Details: + - Step 1: Schedules the sequences to be executed in the next + iteration and the token blocks to be swapped in/out/copy. + + - Depending on the scheduling policy, + sequences may be `preempted/reordered`. + - A Sequence Group (SG) refer to a group of sequences + that are generated from the same prompt. + + - Step 2: Calls the distributed executor to execute the model. + - Step 3: Processes the model output. This mainly includes: + + - Decodes the relevant outputs. + - Updates the scheduled sequence groups with model outputs + based on its `sampling parameters` (`use_beam_search` or not). + - Frees the finished sequence groups. + + - Finally, it creates and returns the newly generated results. + + Example: + ``` + # Please see the example/ folder for more detailed examples. + + # initialize engine and request arguments + engine = LLMEngine.from_engine_args(engine_args) + example_inputs = [(0, "What is LLM?", + SamplingParams(temperature=0.0))] + + # Start the engine with an event loop + while True: + if example_inputs: + req_id, prompt, sampling_params = example_inputs.pop(0) + engine.add_request(str(req_id),prompt,sampling_params) + + # continue the request processing + request_outputs = engine.step() + for request_output in request_outputs: + if request_output.finished: + # return or show the request output + + if not (engine.has_unfinished_requests() or example_inputs): + break + ``` + """ + if self.parallel_config.pipeline_parallel_size > 1: + raise NotImplementedError( + "Pipeline parallelism is only supported through AsyncLLMEngine " + "as performance will be severely degraded otherwise.") + + # For llm_engine, there is no pipeline parallel support, so the engine + # used is always 0. + virtual_engine = 0 + + # These are cached outputs from previous iterations. None if on first + # iteration + cached_outputs = self.cached_scheduler_outputs[virtual_engine] + seq_group_metadata_list = cached_outputs.seq_group_metadata_list + scheduler_outputs = cached_outputs.scheduler_outputs + allow_async_output_proc = cached_outputs.allow_async_output_proc + + ctx = self.scheduler_contexts[virtual_engine] + + # Clear outputs for each new scheduler iteration + ctx.request_outputs.clear() + + # Skip the scheduler if there are any remaining steps in the seq groups. + # This ensures that the scheduler is only called again when the current + # batch has completed. + # The scheduler is also skipped if a single request caused the last + # engine step to fail, and the previous schedule needs to be rerun. + if not self._has_remaining_steps( + seq_group_metadata_list + ) and not self._skip_scheduling_next_step: + # Schedule iteration + (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc + ) = self.scheduler[virtual_engine].schedule() + + ctx.seq_group_metadata_list = seq_group_metadata_list + ctx.scheduler_outputs = scheduler_outputs + + finished_requests_ids = self.scheduler[ + virtual_engine].get_and_reset_finished_requests_ids() + # When n>1, elements in self.seq_id_to_seq_group should be deleted + # here, otherwise memory leaks. + for finished_request_id in finished_requests_ids: + if finished_request_id in self.seq_id_to_seq_group: + del self.seq_id_to_seq_group[finished_request_id] + + # Maybe switch from async mode to sync mode + if not allow_async_output_proc and len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + + if (self.scheduler_config.is_multi_step + and scheduler_outputs.num_lookahead_slots > 0): + # cache the scheduler outputs for the next iteration if we have + # lookahead slots + self._cache_scheduler_outputs_for_multi_step( + virtual_engine, seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) + else: + finished_requests_ids = list() + + assert seq_group_metadata_list is not None + assert scheduler_outputs is not None + + if not scheduler_outputs.is_empty(): + + # Check if we have a cached last_output from the previous iteration. + # For supporting PP this is probably the best way to pass the + # sampled_token_ids, as a separate broadcast over all the PP stages + # will cause one virtual engine's microbatch to block the pipeline. + last_sampled_token_ids = \ + self._get_last_sampled_token_ids(virtual_engine) + + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + num_lookahead_slots=scheduler_outputs.num_lookahead_slots, + running_queue_size=scheduler_outputs.running_queue_size, + finished_requests_ids=finished_requests_ids, + # We use ExecuteModelRequest to pass the last sampled_token_ids + # to each of the non-last PP stages for in-place prepare_input. + last_sampled_token_ids=last_sampled_token_ids) + + if allow_async_output_proc: + execute_model_req.async_callback = self.async_callbacks[ + virtual_engine] + + try: + outputs = self.model_executor.execute_model( + execute_model_req=execute_model_req) + self._skip_scheduling_next_step = False + except InputProcessingError as e: + # The input for this request cannot be processed, so we must + # abort it. If there are remaining requests in the batch that + # have been scheduled, they will be retried on the next step. + invalid_request_id = e.request_id + self._abort_and_cache_schedule( + request_id=invalid_request_id, + virtual_engine=virtual_engine, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + allow_async_output_proc=allow_async_output_proc) + # Raise so the caller is notified that this request failed + raise + + # We need to do this here so that last step's sampled_token_ids can + # be passed to the next iteration for PP. + if self.scheduler_config.is_multi_step: + self._update_cached_scheduler_output(virtual_engine, outputs) + else: + # Nothing scheduled => If there is pending async postprocessor, + # then finish it here. + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + # No outputs in this case + outputs = [] + + # Finish the current step for all the sequence groups. + if self.scheduler_config.is_multi_step: + for seq_group in seq_group_metadata_list: + seq_group.finish_step() + + if not self._has_remaining_steps(seq_group_metadata_list): + # clear the cache if we have finished all the steps. + if self.scheduler_config.is_multi_step: + self.cached_scheduler_outputs[0] = SchedulerOutputState() + + # is_first_step_output is True only when the num_steps of all + # the sequences are 1. When the num_steps > 1, + # multi_step_model_runner does the first-step output append. + is_first_step_output: bool = False if not seq_group_metadata_list \ + else seq_group_metadata_list[0].state.num_steps == 1 + + # Add results to the output_queue + ctx.append_output(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=allow_async_output_proc, + is_last_step=True, + is_first_step_output=is_first_step_output) + + if outputs and allow_async_output_proc: + assert len(outputs) == 1, ( + "Async postprocessor expects only a single output set") + + self._advance_to_next_step( + outputs[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) + + # Check if need to run the usual non-async path + if not allow_async_output_proc: + self._process_model_outputs(ctx=ctx) + + # Log stats. + self.do_log_stats(scheduler_outputs, outputs) + + # Tracing + self.do_tracing(scheduler_outputs) + else: + # Multi-step case + return ctx.request_outputs + + if not self.has_unfinished_requests(): + # Drain async postprocessor (if exists) + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + assert len(ctx.output_queue) == 0 + + # Stop the execute model loop in parallel workers until there are + # more requests to process. This avoids waiting indefinitely in + # torch.distributed ops which may otherwise timeout, and unblocks + # the RPC thread in the workers so that they can process any other + # queued control plane messages, such as add/remove lora adapters. + logger.debug("Stopping remote worker execution loop.") + self.model_executor.stop_remote_worker_execution_loop() + + return ctx.request_outputs + + def _abort_and_cache_schedule( + self, request_id: str, virtual_engine: int, + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduler_outputs: SchedulerOutputs, + allow_async_output_proc: bool) -> None: + """Aborts a single request, and caches the scheduler outputs minus that + request. This allows the next step to continue processing the remaining + requests without having to re-run the scheduler.""" + + # Abort the request and remove its sequence group from the current + # schedule + self.abort_request(request_id) + for i, metadata in enumerate(seq_group_metadata_list): + if metadata.request_id == request_id: + del seq_group_metadata_list[i] + break + for i, group in enumerate(scheduler_outputs.scheduled_seq_groups): + if group.seq_group.request_id == request_id: + del scheduler_outputs.scheduled_seq_groups[i] + break + + # If there are still other sequence groups left in the schedule, cache + # them and flag the engine to reuse the schedule. + if len(seq_group_metadata_list) > 0: + self._skip_scheduling_next_step = True + # Reuse multi-step caching logic + self._cache_scheduler_outputs_for_multi_step( + virtual_engine=virtual_engine, + scheduler_outputs=scheduler_outputs, + seq_group_metadata_list=seq_group_metadata_list, + allow_async_output_proc=allow_async_output_proc) + + def _has_remaining_steps( + self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] + ) -> bool: + if (not self.scheduler_config.is_multi_step + or not seq_group_metadata_list): + return False + + # TODO(will) this is a sanity check for nowto make sure that all the + # seqs are on the same steps. Eventually we will want to do some sort of + # dynamic scheduling when doing multi-step decoding. + ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps + if any([ + seq_group.state.remaining_steps != ref_remaining_steps + for seq_group in seq_group_metadata_list[1:] + ]): + raise AssertionError("All running sequence groups should " + "have the same remaining steps.") + + return ref_remaining_steps > 0 + + def _cache_scheduler_outputs_for_multi_step( + self, virtual_engine: int, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + scheduler_outputs: SchedulerOutputs, + allow_async_output_proc: bool) -> None: + co = self.cached_scheduler_outputs[virtual_engine] + + co.seq_group_metadata_list = seq_group_metadata_list + co.scheduler_outputs = scheduler_outputs + co.allow_async_output_proc = allow_async_output_proc + co.last_output = None + + def _update_cached_scheduler_output( + self, virtual_engine: int, + output: List[Optional[SamplerOutput]]) -> None: + if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0 + and output[0] is not None): + last_output = output[-1] + assert last_output is not None + assert last_output.sampled_token_ids_cpu is not None + assert last_output.sampled_token_ids is None + assert last_output.sampled_token_probs is None + self.cached_scheduler_outputs[ + virtual_engine].last_output = last_output + + def _get_last_sampled_token_ids( + self, virtual_engine: int) -> Optional[torch.Tensor]: + cached_last_output = self.cached_scheduler_outputs[ + virtual_engine].last_output + if (self.scheduler_config.is_multi_step + and self.parallel_config.pipeline_parallel_size > 1 + and cached_last_output is not None + and cached_last_output.sampled_token_ids_cpu is not None): + return cached_last_output.sampled_token_ids_cpu + return None + + def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: + if not self.log_stats: + raise RuntimeError( + "Stat logging is disabled. Set `disable_log_stats=False` " + "argument to enable.") + if logger_name in self.stat_loggers: + raise KeyError(f"Logger with name {logger_name} already exists.") + self.stat_loggers[logger_name] = logger + + def remove_logger(self, logger_name: str) -> None: + if not self.log_stats: + raise RuntimeError( + "Stat logging is disabled. Set `disable_log_stats=False` " + "argument to enable.") + if logger_name not in self.stat_loggers: + raise KeyError(f"Logger with name {logger_name} does not exist.") + del self.stat_loggers[logger_name] + + def do_log_stats(self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None, + finished_before: Optional[List[int]] = None, + skip: Optional[List[int]] = None) -> None: + """Forced log when no requests active.""" + if self.log_stats: + stats = self._get_stats(scheduler_outputs, model_output, + finished_before, skip) + for logger in self.stat_loggers.values(): + logger.log(stats) + + def _get_stats(self, + scheduler_outputs: Optional[SchedulerOutputs], + model_output: Optional[List[SamplerOutput]] = None, + finished_before: Optional[List[int]] = None, + skip: Optional[List[int]] = None) -> Stats: + """Get Stats to be Logged to Prometheus. + + Args: + scheduler_outputs: Optional, used to populate metrics related to + the scheduled batch, + model_output: Optional, used to emit speculative decoding metrics + which are created by the workers. + finished_before: Optional, indices of sequences that were finished + before. These sequences will be ignored. + skip: Optional, indices of sequences that were preempted. These + sequences will be ignored. + """ + now = time.time() + + # System State + # Scheduler State + num_running_sys = sum( + len(scheduler.running) for scheduler in self.scheduler) + num_swapped_sys = sum( + len(scheduler.swapped) for scheduler in self.scheduler) + num_waiting_sys = sum( + len(scheduler.waiting) for scheduler in self.scheduler) + + # KV Cache Usage in % + num_total_gpu = self.cache_config.num_gpu_blocks + gpu_cache_usage_sys = 0. + if num_total_gpu: # Guard against both None and 0 + num_free_gpu = sum( + scheduler.block_manager.get_num_free_gpu_blocks() + for scheduler in self.scheduler) + gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) + + num_total_cpu = self.cache_config.num_cpu_blocks + cpu_cache_usage_sys = 0. + if num_total_cpu: # Guard against both None and 0 + num_free_cpu = sum( + scheduler.block_manager.get_num_free_cpu_blocks() + for scheduler in self.scheduler) + cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) + + # Prefix Cache Hit Rate. Note that we always use + # the cache hit rate of the first virtual engine. + cpu_prefix_cache_hit_rate = self.scheduler[ + 0].get_prefix_cache_hit_rate(Device.CPU) + gpu_prefix_cache_hit_rate = self.scheduler[ + 0].get_prefix_cache_hit_rate(Device.GPU) + + # Exchange the uasge and cache hit stats between gpu and cpu when + # running on cpu because the cpu_worker.py intentionally reports the + # number of cpu blocks as gpu blocks in favor of cache management. + if self.device_config.device_type == "cpu": + num_total_gpu, num_total_cpu = num_total_cpu, num_total_gpu + gpu_cache_usage_sys, cpu_cache_usage_sys = ( + cpu_cache_usage_sys, + gpu_cache_usage_sys, + ) + gpu_prefix_cache_hit_rate, cpu_prefix_cache_hit_rate = ( + cpu_prefix_cache_hit_rate, + gpu_prefix_cache_hit_rate, + ) + + # Iteration stats + num_prompt_tokens_iter = 0 + num_generation_tokens_iter = 0 + num_tokens_iter = 0 + time_to_first_tokens_iter: List[float] = [] + time_per_output_tokens_iter: List[float] = [] + num_preemption_iter = (0 if scheduler_outputs is None else + scheduler_outputs.preempted) + + # Request stats + # Latency + time_e2e_requests: List[float] = [] + time_queue_requests: List[float] = [] + time_inference_requests: List[float] = [] + time_prefill_requests: List[float] = [] + time_decode_requests: List[float] = [] + # Metadata + num_prompt_tokens_requests: List[int] = [] + num_generation_tokens_requests: List[int] = [] + n_requests: List[int] = [] + max_num_generation_tokens_requests: List[int] = [] + max_tokens_requests: List[int] = [] + finished_reason_requests: List[str] = [] + + # LoRA requests + running_lora_adapters = dict( + collectionsCounter([ + running_request.lora_request.lora_name + for scheduler in self.scheduler + for running_request in scheduler.running + if running_request.lora_request + ])) + waiting_lora_adapters = dict( + collectionsCounter([ + waiting_request.lora_request.lora_name + for scheduler in self.scheduler + for waiting_request in scheduler.waiting + if waiting_request.lora_request + ])) + max_lora_stat = "0" + if self.lora_config: + max_lora_stat = str(self.lora_config.max_loras) + + # NOTE: This loop assumes prefill seq_groups are before + # decode seq_groups in scheduled_seq_groups. + if scheduler_outputs is not None: + # For async postprocessor, already finished sequences need to be + # not counted (to avoid double counting) + actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore + + num_generation_tokens_from_prefill_groups = 0 + # NOTE: if scheduler_outputs.num_prefill_groups > 0 and + # the len of scheduler_outputs.scheduled_seq_groups is != + # scheduler_outputs.num_prefill_groups, this means that + # chunked prefills have been detected. + + for idx, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): + # Skip double logging when using async output proc + if finished_before and idx in finished_before: + actual_num_batched_tokens -= 1 + continue + + # Currently, skip == preempted sequences, so we need to skip + # their log stats + if skip and idx in skip: + continue + + group_was_prefill = idx < scheduler_outputs.num_prefill_groups + seq_group = scheduled_seq_group.seq_group + + # NOTE: a seq_group that completed all of its prefill tokens + # in the last iteration will have seq_group.is_prefill() = False + # with group_was_prefill = True + if group_was_prefill: + # Number of prompt tokens. + num_prompt_tokens_iter += ( + scheduled_seq_group.token_chunk_size) + + # If the seq_group just finished the prefill state + # get TTFT. + if not seq_group.is_prefill(): + latency = seq_group.get_last_token_latency() + time_to_first_tokens_iter.append(latency) + + # One generation token per finished prefill. + num_generation_tokens_from_prefill_groups += ( + seq_group.num_seqs()) + else: + # TPOTs. + latency = seq_group.get_last_token_latency() + time_per_output_tokens_iter.append(latency) + if seq_group.state.current_step == 0: + # For async_output_proc, the do_log_stats() + # is called following init_multi_step(), which + # sets the current_step to zero. + actual_num_batched_tokens +=\ + seq_group.state.num_steps - 1 + else: + actual_num_batched_tokens +=\ + seq_group.state.current_step - 1 + + # Because of chunked prefill, we can have a single sequence + # group that does multiple prompt_runs. To prevent logging + # the same metadata more than once per request, we standardize + # on logging request level information for finished requests, + # which can only happen once. + if seq_group.is_finished(): + # Latency timings + time_e2e_requests.append(now - + seq_group.metrics.arrival_time) + if (seq_group.metrics.first_scheduled_time is not None and + seq_group.metrics.first_token_time is not None): + time_queue_requests.append( + seq_group.metrics.first_scheduled_time - + seq_group.metrics.arrival_time) + time_prefill_requests.append( + seq_group.metrics.first_token_time - + seq_group.metrics.first_scheduled_time) + time_decode_requests.append( + now - seq_group.metrics.first_token_time) + time_inference_requests.append( + now - seq_group.metrics.first_scheduled_time) + # Metadata + num_prompt_tokens_requests.append( + len(seq_group.prompt_token_ids)) + num_generation_tokens_requests.extend([ + seq.get_output_len() + for seq in seq_group.get_finished_seqs() + ]) + max_num_generation_tokens_requests.append( + max(seq.get_output_len() + for seq in seq_group.get_seqs())) + if seq_group.sampling_params is not None: + n_requests.append(seq_group.sampling_params.n) + max_tokens_requests.append( + seq_group.sampling_params.max_tokens) + finished_reason_requests.extend([ + SequenceStatus.get_finished_reason(seq.status) + for seq in seq_group.get_finished_seqs() + ]) + + # Number of generation tokens. + # num_batched_tokens equals the number of prompt_tokens plus the + # number of decode_tokens in a single iteration. So, + # num_generation_tokens = num_batched_tokens - num_prompt_tokens + # + num_generation_tokens_from_prefill_groups (since we generate + # one token on prefills on iters where the prefill finishes). + num_generation_tokens_iter = ( + actual_num_batched_tokens - num_prompt_tokens_iter + + num_generation_tokens_from_prefill_groups) + num_tokens_iter = (num_generation_tokens_iter + + num_prompt_tokens_iter) + # Spec decode, if enabled, emits specialized metrics from the worker in + # sampler output. + if model_output and isinstance(model_output[0], SamplerOutput) and ( + model_output[0].spec_decode_worker_metrics is not None): + spec_decode_metrics = model_output[0].spec_decode_worker_metrics + else: + spec_decode_metrics = None + + return Stats( + now=now, + # System stats + # Scheduler State + num_running_sys=num_running_sys, + num_swapped_sys=num_swapped_sys, + num_waiting_sys=num_waiting_sys, + # KV Cache Usage in % + gpu_cache_usage_sys=gpu_cache_usage_sys, + cpu_cache_usage_sys=cpu_cache_usage_sys, + # Prefix Cache Hit Rate + cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate, + gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate, + + # Iteration stats + num_prompt_tokens_iter=num_prompt_tokens_iter, + num_generation_tokens_iter=num_generation_tokens_iter, + num_tokens_iter=num_tokens_iter, + time_to_first_tokens_iter=time_to_first_tokens_iter, + time_per_output_tokens_iter=time_per_output_tokens_iter, + spec_decode_metrics=spec_decode_metrics, + num_preemption_iter=num_preemption_iter, + + # Request stats + # Latency + time_e2e_requests=time_e2e_requests, + time_queue_requests=time_queue_requests, + time_inference_requests=time_inference_requests, + time_prefill_requests=time_prefill_requests, + time_decode_requests=time_decode_requests, + # Metadata + num_prompt_tokens_requests=num_prompt_tokens_requests, + num_generation_tokens_requests=num_generation_tokens_requests, + max_num_generation_tokens_requests= + max_num_generation_tokens_requests, + n_requests=n_requests, + max_tokens_requests=max_tokens_requests, + finished_reason_requests=finished_reason_requests, + max_lora=str(max_lora_stat), + waiting_lora_adapters=list(waiting_lora_adapters.keys()), + running_lora_adapters=list(running_lora_adapters.keys())) + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_executor.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_executor.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_executor.list_loras() + + def pin_lora(self, lora_id: int) -> bool: + return self.model_executor.pin_lora(lora_id) + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return self.model_executor.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.model_executor.remove_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> List[int]: + return self.model_executor.list_prompt_adapters() + + def start_profile(self) -> None: + self.model_executor.start_profile() + + def stop_profile(self) -> None: + self.model_executor.stop_profile() + + def sleep(self, level: int = 1) -> None: + assert self.vllm_config.model_config.enable_sleep_mode, ( + "Sleep mode is not enabled in the model config") + self.model_executor.sleep(level=level) + + def wake_up(self, tags: Optional[list[str]] = None) -> None: + assert self.vllm_config.model_config.enable_sleep_mode, ( + "Sleep mode is not enabled in the model config") + self.model_executor.wake_up(tags) + + def is_sleeping(self) -> bool: + return self.model_executor.is_sleeping + + def check_health(self) -> None: + self.model_executor.check_health() + + def is_tracing_enabled(self) -> bool: + return self.tracer is not None + + def do_tracing(self, + scheduler_outputs: SchedulerOutputs, + finished_before: Optional[List[int]] = None) -> None: + if self.tracer is None: + return + + for idx, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): + # Skip double tracing when using async output proc + if finished_before and idx in finished_before: + continue + + seq_group = scheduled_seq_group.seq_group + if seq_group.is_finished(): + self.create_trace_span(seq_group) + + def create_trace_span(self, seq_group: SequenceGroup) -> None: + if self.tracer is None or seq_group.sampling_params is None: + return + arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9) + + trace_context = extract_trace_context(seq_group.trace_headers) + + with self.tracer.start_as_current_span( + "llm_request", + kind=SpanKind.SERVER, + context=trace_context, + start_time=arrival_time_nano_seconds) as seq_span: + metrics = seq_group.metrics + ttft = metrics.first_token_time - metrics.arrival_time + e2e_time = metrics.finished_time - metrics.arrival_time + seq_span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL, + self.model_config.model) + seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, + seq_group.request_id) + seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, + seq_group.sampling_params.temperature) + seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, + seq_group.sampling_params.top_p) + seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, + seq_group.sampling_params.max_tokens) + seq_span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, + seq_group.sampling_params.n) + seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_NUM_SEQUENCES, + seq_group.num_seqs()) + seq_span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, + len(seq_group.prompt_token_ids)) + seq_span.set_attribute( + SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS, + sum([ + seq.get_output_len() + for seq in seq_group.get_finished_seqs() + ])) + seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, + metrics.time_in_queue) + seq_span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft) + seq_span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time) + if metrics.scheduler_time is not None: + seq_span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER, + metrics.scheduler_time) + if metrics.model_forward_time is not None: + seq_span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD, + metrics.model_forward_time / 1000.0) + if metrics.model_execute_time is not None: + seq_span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE, + metrics.model_execute_time) + + def _validate_model_inputs(self, inputs: ProcessorInputs, + lora_request: Optional[LoRARequest]): + encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs) + + if encoder_inputs is not None: + self._validate_model_input(encoder_inputs, + lora_request, + prompt_type="encoder") + + self._validate_model_input(decoder_inputs, + lora_request, + prompt_type="decoder") + + def _validate_model_input( + self, + prompt_inputs: SingletonInputs, + lora_request: Optional[LoRARequest], + *, + prompt_type: Literal["encoder", "decoder"], + ): + model_config = self.model_config + if self.tokenizer is None: + tokenizer = None + elif self.model_config.tokenizer_mode != "cpm": + tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) + else: + tokenizer = self.tokenizer + # tokenizer = (None if self.tokenizer is None else + # self.tokenizer.get_lora_tokenizer(lora_request)) + + prompt_ids = prompt_inputs.get("prompt_token_ids", []) + if not prompt_ids: + if prompt_type == "encoder" and model_config.is_multimodal_model: + pass # Mllama may have empty encoder inputs for text-only data + elif prompt_inputs["type"] == "embeds": + pass + else: + raise ValueError(f"The {prompt_type} prompt cannot be empty") + + if tokenizer is not None: + max_input_id = max(prompt_ids, default=0) + if max_input_id > tokenizer.max_token_id: + raise ValueError( + f"Token id {max_input_id} is out of vocabulary") + + max_prompt_len = self.model_config.max_model_len + if len(prompt_ids) > max_prompt_len: + if prompt_type == "encoder" and model_config.is_multimodal_model: + mm_registry = self.input_preprocessor.mm_registry + mm_processor = mm_registry.create_processor( + model_config, + tokenizer=tokenizer or object(), # Dummy if no tokenizer + ) + assert isinstance(mm_processor, EncDecMultiModalProcessor) + + if mm_processor.pad_dummy_encoder_prompt: + return # Skip encoder length check for Whisper + + if model_config.is_multimodal_model: + suggestion = ( + "Make sure that `max_model_len` is no smaller than the " + "number of text tokens plus multimodal tokens. For image " + "inputs, the number of image tokens depends on the number " + "of images, and possibly their aspect ratios as well.") + else: + suggestion = ( + "Make sure that `max_model_len` is no smaller than the " + "number of text tokens.") + + raise ValueError( + f"The {prompt_type} prompt (length {len(prompt_ids)}) is " + f"longer than the maximum model length of {max_prompt_len}. " + f"{suggestion}") + + # TODO: Find out how many placeholder tokens are there so we can + # check that chunked prefill does not truncate them + # max_batch_len = self.scheduler_config.max_num_batched_tokens + + def _build_logits_processors( + self, sampling_params: SamplingParams, + lora_request: Optional[LoRARequest]) -> SamplingParams: + """Constructs logits processors based on the guided_decoding, + logits_bias, and allowed_token_ids fields in sampling_params. Deletes + those fields and adds the constructed logits processors to the + logits_processors field. Returns the modified sampling params.""" + + logits_processors = [] + + if sampling_params.guided_decoding is not None: + # Defensively copy sampling params since guided decoding logits + # processors can have different state for each request + sampling_params = copy.copy(sampling_params) + guided_decoding = sampling_params.guided_decoding + + logger.debug( + "Building guided decoding logits processor in " + "LLMEngine. Params: %s", guided_decoding) + + tokenizer = self.get_tokenizer(lora_request=lora_request) + guided_decoding.backend = guided_decoding.backend or \ + self.decoding_config.backend + + if self.decoding_config.reasoning_backend: + logger.debug("Building with reasoning backend %s", + self.decoding_config.reasoning_backend) + + processor = get_local_guided_decoding_logits_processor( + guided_params=guided_decoding, + tokenizer=tokenizer, + model_config=self.model_config, + reasoning_backend=self.decoding_config.reasoning_backend, + ) + if processor: + logits_processors.append(processor) + + # Unset so this doesn't get passed down to the model + sampling_params.guided_decoding = None + + if (sampling_params.logit_bias or sampling_params.allowed_token_ids): + tokenizer = self.get_tokenizer(lora_request=lora_request) + + processors = get_openai_logits_processors( + logit_bias=sampling_params.logit_bias, + allowed_token_ids=sampling_params.allowed_token_ids, + tokenizer=tokenizer) + logits_processors.extend(processors) + + # Unset so these don't get passed down to the model + sampling_params.logit_bias = None + sampling_params.allowed_token_ids = None + + if len(sampling_params.bad_words) > 0: + tokenizer = self.get_tokenizer(lora_request) + processors = get_bad_words_logits_processors( + bad_words=sampling_params.bad_words, tokenizer=tokenizer) + logits_processors.extend(processors) + + if logits_processors: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = logits_processors + else: + sampling_params.logits_processors.extend(logits_processors) + + return sampling_params + + def collective_rpc(self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + return self.model_executor.collective_rpc(method, timeout, args, + kwargs) + + +if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: + from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine + LLMEngine = V1LLMEngine # type: ignore diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py new file mode 100644 index 0000000..8d51f04 --- /dev/null +++ b/vllm/engine/metrics.py @@ -0,0 +1,629 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import time +from typing import TYPE_CHECKING +from typing import Counter as CollectionsCounter +from typing import Dict, List, Optional, Type, Union, cast + +import numpy as np +import prometheus_client + +from vllm.config import SupportsMetricsInfo, VllmConfig +from vllm.engine.metrics_types import StatLoggerBase, Stats +from vllm.executor.ray_utils import ray +from vllm.logger import init_logger + +if ray is not None: + from ray.util import metrics as ray_metrics +else: + ray_metrics = None + +if TYPE_CHECKING: + from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics + +logger = init_logger(__name__) + +prometheus_client.disable_created_metrics() + +# The begin-* and end* here are used by the documentation generator +# to extract the metrics definitions. + + +# --8<-- [start:metrics-definitions] +class Metrics: + """ + vLLM uses a multiprocessing-based frontend for the OpenAI server. + This means that we need to run prometheus_client in multiprocessing mode + See https://prometheus.github.io/client_python/multiprocess/ for more + details on limitations. + """ + + labelname_finish_reason = "finished_reason" + labelname_waiting_lora_adapters = "waiting_lora_adapters" + labelname_running_lora_adapters = "running_lora_adapters" + labelname_max_lora = "max_lora" + _gauge_cls = prometheus_client.Gauge + _counter_cls = prometheus_client.Counter + _histogram_cls = prometheus_client.Histogram + + def __init__(self, labelnames: List[str], vllm_config: VllmConfig): + # Unregister any existing vLLM collectors (for CI/CD) + self._unregister_vllm_metrics() + + max_model_len = vllm_config.model_config.max_model_len + + # Use this flag to hide metrics that were deprecated in + # a previous release and which will be removed future + self.show_hidden_metrics = \ + vllm_config.observability_config.show_hidden_metrics + + # System stats + # Scheduler State + self.gauge_scheduler_running = self._gauge_cls( + name="vllm:num_requests_running", + documentation="Number of requests currently running on GPU.", + labelnames=labelnames, + multiprocess_mode="sum") + self.gauge_scheduler_waiting = self._gauge_cls( + name="vllm:num_requests_waiting", + documentation="Number of requests waiting to be processed.", + labelnames=labelnames, + multiprocess_mode="sum") + self.gauge_lora_info = self._gauge_cls( + name="vllm:lora_requests_info", + documentation="Running stats on lora requests.", + labelnames=[ + self.labelname_running_lora_adapters, + self.labelname_max_lora, + self.labelname_waiting_lora_adapters, + ], + multiprocess_mode="livemostrecent", + ) + + # KV Cache Usage in % + self.gauge_gpu_cache_usage = self._gauge_cls( + name="vllm:gpu_cache_usage_perc", + documentation="GPU KV-cache usage. 1 means 100 percent usage.", + labelnames=labelnames, + multiprocess_mode="sum") + + # Iteration stats + self.counter_num_preemption = self._counter_cls( + name="vllm:num_preemptions_total", + documentation="Cumulative number of preemption from the engine.", + labelnames=labelnames) + self.counter_prompt_tokens = self._counter_cls( + name="vllm:prompt_tokens_total", + documentation="Number of prefill tokens processed.", + labelnames=labelnames) + self.counter_generation_tokens = self._counter_cls( + name="vllm:generation_tokens_total", + documentation="Number of generation tokens processed.", + labelnames=labelnames) + self.histogram_iteration_tokens = self._histogram_cls( + name="vllm:iteration_tokens_total", + documentation="Histogram of number of tokens per engine_step.", + labelnames=labelnames, + buckets=[ + 1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384 + ]) + self.histogram_time_to_first_token = self._histogram_cls( + name="vllm:time_to_first_token_seconds", + documentation="Histogram of time to first token in seconds.", + labelnames=labelnames, + buckets=[ + 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, + 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0, + 2560.0 + ]) + self.histogram_time_per_output_token = self._histogram_cls( + name="vllm:time_per_output_token_seconds", + documentation="Histogram of time per output token in seconds.", + labelnames=labelnames, + buckets=[ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, + 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 + ]) + + # Request stats + # Latency + request_latency_buckets = [ + 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, + 40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0 + ] + self.histogram_e2e_time_request = self._histogram_cls( + name="vllm:e2e_request_latency_seconds", + documentation="Histogram of end to end request latency in seconds.", + labelnames=labelnames, + buckets=request_latency_buckets) + self.histogram_queue_time_request = self._histogram_cls( + name="vllm:request_queue_time_seconds", + documentation= + "Histogram of time spent in WAITING phase for request.", + labelnames=labelnames, + buckets=request_latency_buckets) + self.histogram_inference_time_request = self._histogram_cls( + name="vllm:request_inference_time_seconds", + documentation= + "Histogram of time spent in RUNNING phase for request.", + labelnames=labelnames, + buckets=request_latency_buckets) + self.histogram_prefill_time_request = self._histogram_cls( + name="vllm:request_prefill_time_seconds", + documentation= + "Histogram of time spent in PREFILL phase for request.", + labelnames=labelnames, + buckets=request_latency_buckets) + self.histogram_decode_time_request = self._histogram_cls( + name="vllm:request_decode_time_seconds", + documentation= + "Histogram of time spent in DECODE phase for request.", + labelnames=labelnames, + buckets=request_latency_buckets) + + # Metadata + self.histogram_num_prompt_tokens_request = self._histogram_cls( + name="vllm:request_prompt_tokens", + documentation="Number of prefill tokens processed.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + self.histogram_num_generation_tokens_request = \ + self._histogram_cls( + name="vllm:request_generation_tokens", + documentation="Number of generation tokens processed.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + self.histogram_max_num_generation_tokens_request = self._histogram_cls( + name="vllm:request_max_num_generation_tokens", + documentation= + "Histogram of maximum number of requested generation tokens.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len)) + self.histogram_n_request = self._histogram_cls( + name="vllm:request_params_n", + documentation="Histogram of the n request parameter.", + labelnames=labelnames, + buckets=[1, 2, 5, 10, 20], + ) + self.histogram_max_tokens_request = self._histogram_cls( + name="vllm:request_params_max_tokens", + documentation="Histogram of the max_tokens request parameter.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + self.counter_request_success = self._counter_cls( + name="vllm:request_success_total", + documentation="Count of successfully processed requests.", + labelnames=labelnames + [Metrics.labelname_finish_reason]) + + # Speculative decoding stats + self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls( + name="vllm:spec_decode_draft_acceptance_rate", + documentation="Speulative token acceptance rate.", + labelnames=labelnames, + multiprocess_mode="sum") + self.gauge_spec_decode_efficiency = self._gauge_cls( + name="vllm:spec_decode_efficiency", + documentation="Speculative decoding system efficiency.", + labelnames=labelnames, + multiprocess_mode="sum") + self.counter_spec_decode_num_accepted_tokens = (self._counter_cls( + name="vllm:spec_decode_num_accepted_tokens_total", + documentation="Number of accepted tokens.", + labelnames=labelnames)) + self.counter_spec_decode_num_draft_tokens = self._counter_cls( + name="vllm:spec_decode_num_draft_tokens_total", + documentation="Number of draft tokens.", + labelnames=labelnames) + self.counter_spec_decode_num_emitted_tokens = (self._counter_cls( + name="vllm:spec_decode_num_emitted_tokens_total", + documentation="Number of emitted tokens.", + labelnames=labelnames)) + + +# --8<-- [end:metrics-definitions] + + def _unregister_vllm_metrics(self) -> None: + for collector in list(prometheus_client.REGISTRY._collector_to_names): + if hasattr(collector, "_name") and "vllm" in collector._name: + prometheus_client.REGISTRY.unregister(collector) + + +class _RayGaugeWrapper: + """Wraps around ray.util.metrics.Gauge to provide same API as + prometheus_client.Gauge""" + + def __init__(self, + name: str, + documentation: str = "", + labelnames: Optional[List[str]] = None, + multiprocess_mode: str = ""): + del multiprocess_mode + labelnames_tuple = tuple(labelnames) if labelnames else None + self._gauge = ray_metrics.Gauge(name=name, + description=documentation, + tag_keys=labelnames_tuple) + + def labels(self, **labels): + self._gauge.set_default_tags(labels) + return self + + def set(self, value: Union[int, float]): + return self._gauge.set(value) + + def set_to_current_time(self): + # ray metrics doesn't have set_to_current time, https://docs.ray.io/en/latest/_modules/ray/util/metrics.html + return self._gauge.set(time.time()) + + +class _RayCounterWrapper: + """Wraps around ray.util.metrics.Counter to provide same API as + prometheus_client.Counter""" + + def __init__(self, + name: str, + documentation: str = "", + labelnames: Optional[List[str]] = None): + labelnames_tuple = tuple(labelnames) if labelnames else None + self._counter = ray_metrics.Counter(name=name, + description=documentation, + tag_keys=labelnames_tuple) + + def labels(self, **labels): + self._counter.set_default_tags(labels) + return self + + def inc(self, value: Union[int, float] = 1.0): + if value == 0: + return + return self._counter.inc(value) + + +class _RayHistogramWrapper: + """Wraps around ray.util.metrics.Histogram to provide same API as + prometheus_client.Histogram""" + + def __init__(self, + name: str, + documentation: str = "", + labelnames: Optional[List[str]] = None, + buckets: Optional[List[float]] = None): + labelnames_tuple = tuple(labelnames) if labelnames else None + boundaries = buckets if buckets else [] + self._histogram = ray_metrics.Histogram(name=name, + description=documentation, + tag_keys=labelnames_tuple, + boundaries=boundaries) + + def labels(self, **labels): + self._histogram.set_default_tags(labels) + return self + + def observe(self, value: Union[int, float]): + return self._histogram.observe(value) + + +class RayMetrics(Metrics): + """ + RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics. + Provides the same metrics as Metrics but uses Ray's util.metrics library. + """ + _gauge_cls: Type[prometheus_client.Gauge] = cast( + Type[prometheus_client.Gauge], _RayGaugeWrapper) + _counter_cls: Type[prometheus_client.Counter] = cast( + Type[prometheus_client.Counter], _RayCounterWrapper) + _histogram_cls: Type[prometheus_client.Histogram] = cast( + Type[prometheus_client.Histogram], _RayHistogramWrapper) + + def __init__(self, labelnames: List[str], vllm_config: VllmConfig): + if ray_metrics is None: + raise ImportError("RayMetrics requires Ray to be installed.") + super().__init__(labelnames, vllm_config) + + def _unregister_vllm_metrics(self) -> None: + # No-op on purpose + pass + + +def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]: + """ + Builds a list of buckets with increasing powers of 10 multiplied by + mantissa values until the value exceeds the specified maximum. + + """ + exponent = 0 + buckets: List[int] = [] + while True: + for m in mantissa_lst: + value = m * 10**exponent + if value <= max_value: + buckets.append(value) + else: + return buckets + exponent += 1 + + +def build_1_2_5_buckets(max_value: int) -> List[int]: + """ + Example: + >>> build_1_2_5_buckets(100) + [1, 2, 5, 10, 20, 50, 100] + """ + return build_buckets([1, 2, 5], max_value) + + +def build_1_2_3_5_8_buckets(max_value: int) -> List[int]: + """ + Example: + >>> build_1_2_3_5_8_buckets(100) + [1, 2, 3, 5, 8, 10, 20, 30, 50, 80, 100] + """ + return build_buckets([1, 2, 3, 5, 8], max_value) + + +def local_interval_elapsed(now: float, last_log: float, + local_interval: float) -> bool: + elapsed_time = now - last_log + return elapsed_time > local_interval + + +def get_throughput(tracked_stats: List[int], now: float, + last_log: float) -> float: + return float(np.sum(tracked_stats) / (now - last_log)) + + +class LoggingStatLogger(StatLoggerBase): + """LoggingStatLogger is used in LLMEngine to log to Stdout.""" + + def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None: + super().__init__(local_interval, vllm_config) + self.last_prompt_throughput: Optional[float] = None + self.last_generation_throughput: Optional[float] = None + + def log(self, stats: Stats) -> None: + """Called by LLMEngine. + Logs to Stdout every self.local_interval seconds.""" + + # Save tracked stats for token counters. + self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) + self.num_generation_tokens.append(stats.num_generation_tokens_iter) + + # Update spec decode metrics + self.maybe_update_spec_decode_metrics(stats) + + # Log locally every local_interval seconds. + if local_interval_elapsed(stats.now, self.last_local_log, + self.local_interval): + # Compute summary metrics for tracked stats (and log them + # to promethus if applicable). + prompt_throughput = get_throughput(self.num_prompt_tokens, + now=stats.now, + last_log=self.last_local_log) + generation_throughput = get_throughput( + self.num_generation_tokens, + now=stats.now, + last_log=self.last_local_log) + + log_fn = logger.info + if not any((prompt_throughput, generation_throughput, + self.last_prompt_throughput, + self.last_generation_throughput)): + # Avoid log noise on an idle production system + log_fn = logger.debug + + log_fn( + "Avg prompt throughput: %.1f tokens/s, " + "Avg generation throughput: %.1f tokens/s, " + "Running: %d reqs, Swapped: %d reqs, " + "Pending: %d reqs, GPU KV cache usage: %.1f%%, " + "CPU KV cache usage: %.1f%%.", + prompt_throughput, + generation_throughput, + stats.num_running_sys, + stats.num_swapped_sys, + stats.num_waiting_sys, + stats.gpu_cache_usage_sys * 100, + stats.cpu_cache_usage_sys * 100, + ) + if (stats.cpu_prefix_cache_hit_rate >= 0 + or stats.gpu_prefix_cache_hit_rate >= 0): + log_fn( + "Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%", + stats.gpu_prefix_cache_hit_rate * 100, + stats.cpu_prefix_cache_hit_rate * 100, + ) + if self.spec_decode_metrics is not None: + log_fn( + self._format_spec_decode_metrics_str( + self.spec_decode_metrics)) + + self._reset(stats, prompt_throughput, generation_throughput) + + def _reset(self, stats, prompt_throughput, generation_throughput) -> None: + # Reset tracked stats for next interval. + self.num_prompt_tokens = [] + self.num_generation_tokens = [] + self.last_local_log = stats.now + self.spec_decode_metrics = None + self.last_prompt_throughput = prompt_throughput + self.last_generation_throughput = generation_throughput + + def _format_spec_decode_metrics_str( + self, metrics: "SpecDecodeWorkerMetrics") -> str: + + return ("Speculative metrics: " + f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, " + f"System efficiency: {metrics.system_efficiency:.3f}, " + f"Number of speculative tokens: {metrics.num_spec_tokens}, " + f"Number of accepted tokens: {metrics.accepted_tokens}, " + f"Number of draft tokens: {metrics.draft_tokens}, " + f"Number of emitted tokens: {metrics.emitted_tokens}.") + + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + raise NotImplementedError + + +class PrometheusStatLogger(StatLoggerBase): + """PrometheusStatLogger is used LLMEngine to log to Promethus.""" + _metrics_cls = Metrics + _gauge_cls = prometheus_client.Gauge + + def __init__(self, local_interval: float, labels: Dict[str, str], + vllm_config: VllmConfig) -> None: + super().__init__(local_interval, vllm_config) + # Prometheus metrics + self.labels = labels + self.metrics = self._metrics_cls(labelnames=list(labels.keys()), + vllm_config=vllm_config) + + def _log_gauge(self, gauge, data: Union[int, float]) -> None: + # Convenience function for logging to gauge. + gauge.labels(**self.labels).set(data) + + def _log_counter(self, counter, data: Union[int, float]) -> None: + # Convenience function for logging to counter. + # Prevent ValueError from negative increment + if data < 0: + logger.warning("Skipping negative increment of %g to %s", data, + counter) + return + counter.labels(**self.labels).inc(data) + + def _log_counter_labels(self, counter, data: CollectionsCounter, + label_key: str) -> None: + # Convenience function for collection counter of labels. + for label, count in data.items(): + counter.labels(**{**self.labels, label_key: label}).inc(count) + + def _log_histogram(self, histogram, data: Union[List[int], + List[float]]) -> None: + # Convenience function for logging list to histogram. + for datum in data: + histogram.labels(**self.labels).observe(datum) + + def _log_gauge_string(self, gauge, data: Dict[str, str]) -> None: + gauge.labels(**data).set_to_current_time() + + def _log_prometheus(self, stats: Stats) -> None: + # System state data + self._log_gauge(self.metrics.gauge_scheduler_running, + stats.num_running_sys) + self._log_gauge(self.metrics.gauge_scheduler_waiting, + stats.num_waiting_sys) + self._log_gauge(self.metrics.gauge_gpu_cache_usage, + stats.gpu_cache_usage_sys) + # Including max-lora in metric, in future this property of lora + # config maybe extended to be dynamic. + lora_info = { + self.metrics.labelname_running_lora_adapters: + ",".join(stats.running_lora_adapters), + self.metrics.labelname_waiting_lora_adapters: + ",".join(stats.waiting_lora_adapters), + self.metrics.labelname_max_lora: + stats.max_lora, + } + self._log_gauge_string(self.metrics.gauge_lora_info, lora_info) + # Iteration level data + self._log_counter(self.metrics.counter_num_preemption, + stats.num_preemption_iter) + self._log_counter(self.metrics.counter_prompt_tokens, + stats.num_prompt_tokens_iter) + self._log_counter(self.metrics.counter_generation_tokens, + stats.num_generation_tokens_iter) + self._log_histogram(self.metrics.histogram_iteration_tokens, + [stats.num_tokens_iter]) + self._log_histogram(self.metrics.histogram_time_to_first_token, + stats.time_to_first_tokens_iter) + self._log_histogram(self.metrics.histogram_time_per_output_token, + stats.time_per_output_tokens_iter) + + # Request level data + # Latency + self._log_histogram(self.metrics.histogram_e2e_time_request, + stats.time_e2e_requests) + self._log_histogram(self.metrics.histogram_queue_time_request, + stats.time_queue_requests) + self._log_histogram(self.metrics.histogram_inference_time_request, + stats.time_inference_requests) + self._log_histogram(self.metrics.histogram_prefill_time_request, + stats.time_prefill_requests) + self._log_histogram(self.metrics.histogram_decode_time_request, + stats.time_decode_requests) + # Metadata + finished_reason_counter = CollectionsCounter( + stats.finished_reason_requests) + self._log_counter_labels(self.metrics.counter_request_success, + finished_reason_counter, + Metrics.labelname_finish_reason) + self._log_histogram(self.metrics.histogram_num_prompt_tokens_request, + stats.num_prompt_tokens_requests) + self._log_histogram( + self.metrics.histogram_num_generation_tokens_request, + stats.num_generation_tokens_requests) + self._log_histogram(self.metrics.histogram_n_request, stats.n_requests) + self._log_histogram( + self.metrics.histogram_max_num_generation_tokens_request, + stats.max_num_generation_tokens_requests) + self._log_histogram(self.metrics.histogram_max_tokens_request, + stats.max_tokens_requests) + + def log(self, stats: Stats): + """Logs to prometheus and tracked stats every iteration.""" + # Log to prometheus. + self._log_prometheus(stats) + + # Save tracked stats for token counters. + self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) + self.num_generation_tokens.append(stats.num_generation_tokens_iter) + + # Update spec decode metrics + self.maybe_update_spec_decode_metrics(stats) + + # Log locally every local_interval seconds. + if local_interval_elapsed(stats.now, self.last_local_log, + self.local_interval): + if self.spec_decode_metrics is not None: + self._log_gauge( + self.metrics.gauge_spec_decode_draft_acceptance_rate, + self.spec_decode_metrics.draft_acceptance_rate) + self._log_gauge(self.metrics.gauge_spec_decode_efficiency, + self.spec_decode_metrics.system_efficiency) + self._log_counter( + self.metrics.counter_spec_decode_num_accepted_tokens, + self.spec_decode_metrics.accepted_tokens) + self._log_counter( + self.metrics.counter_spec_decode_num_draft_tokens, + self.spec_decode_metrics.draft_tokens) + self._log_counter( + self.metrics.counter_spec_decode_num_emitted_tokens, + self.spec_decode_metrics.emitted_tokens) + + # Reset tracked stats for next interval. + self.num_prompt_tokens = [] + self.num_generation_tokens = [] + self.last_local_log = stats.now + self.spec_decode_metrics = None + + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + # Info type metrics are syntactic sugar for a gauge permanently set to 1 + # Since prometheus multiprocessing mode does not support Info, emulate + # info here with a gauge. + if type == "cache_config": + metrics_info = obj.metrics_info() + info_gauge = self._gauge_cls( + name="vllm:cache_config_info", + documentation="Information of the LLMEngine CacheConfig", + labelnames=metrics_info.keys(), + multiprocess_mode="mostrecent") + info_gauge.labels(**metrics_info).set(1) + + +class RayPrometheusStatLogger(PrometheusStatLogger): + """RayPrometheusStatLogger uses Ray metrics instead.""" + _metrics_cls = RayMetrics + + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + return None diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py new file mode 100644 index 0000000..9375dc4 --- /dev/null +++ b/vllm/engine/metrics_types.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +These types are defined in this file to avoid importing vllm.engine.metrics +and therefore importing prometheus_client. + +This is required due to usage of Prometheus multiprocess mode to enable +metrics after splitting out the uvicorn process from the engine process. + +Prometheus multiprocess mode requires setting PROMETHEUS_MULTIPROC_DIR +before prometheus_client is imported. Typically, this is done by setting +the env variable before launch, but since we are a library, we need to +do this in Python code and lazily import prometheus_client. +""" + +import time +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Optional + +from vllm.config import SupportsMetricsInfo, VllmConfig +from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics + + +@dataclass +class Stats: + """Created by LLMEngine for use by StatLogger.""" + now: float + + # System stats (should have _sys suffix) + # Scheduler State + num_running_sys: int + num_waiting_sys: int + num_swapped_sys: int + # KV Cache Usage in % + gpu_cache_usage_sys: float + cpu_cache_usage_sys: float + # Prefix caching block hit rate + cpu_prefix_cache_hit_rate: float + gpu_prefix_cache_hit_rate: float + + # Iteration stats (should have _iter suffix) + num_prompt_tokens_iter: int + num_generation_tokens_iter: int + num_tokens_iter: int + time_to_first_tokens_iter: List[float] + time_per_output_tokens_iter: List[float] + num_preemption_iter: int + + # Request stats (should have _requests suffix) + # Latency + time_e2e_requests: List[float] + time_queue_requests: List[float] + time_inference_requests: List[float] + time_prefill_requests: List[float] + time_decode_requests: List[float] + # Metadata + num_prompt_tokens_requests: List[int] + num_generation_tokens_requests: List[int] + n_requests: List[int] + max_num_generation_tokens_requests: List[int] + max_tokens_requests: List[int] + finished_reason_requests: List[str] + waiting_lora_adapters: List[str] + running_lora_adapters: List[str] + max_lora: str + + spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None + + +class StatLoggerBase(ABC): + """Base class for StatLogger.""" + + def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None: + # Tracked stats over current local logging interval. + self.num_prompt_tokens: List[int] = [] + self.num_generation_tokens: List[int] = [] + self.last_local_log = time.time() + self.local_interval = local_interval + self.spec_decode_metrics: Optional[SpecDecodeWorkerMetrics] = None + + @abstractmethod + def log(self, stats: Stats) -> None: + raise NotImplementedError + + @abstractmethod + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + raise NotImplementedError + + def maybe_update_spec_decode_metrics(self, stats: Stats): + """Save spec decode metrics (since they are unlikely + to be emitted at same time as log interval).""" + if stats.spec_decode_metrics is not None: + self.spec_decode_metrics = stats.spec_decode_metrics diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py new file mode 100644 index 0000000..db968cd --- /dev/null +++ b/vllm/engine/multiprocessing/__init__.py @@ -0,0 +1,148 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import uuid +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Mapping, Optional, Union + +from vllm import PoolingParams +from vllm.inputs import PromptType +from vllm.lora.request import LoRARequest +from vllm.outputs import RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.utils import Device + +VLLM_RPC_SUCCESS_STR = "SUCCESS" + +IPC_INPUT_EXT = "_input_socket" +IPC_OUTPUT_EXT = "_output_socket" +IPC_HEALTH_EXT = "_health_socket" +IPC_DATA_EXT = "_data_socket" + + +class MQEngineDeadError(RuntimeError): + pass + + +@dataclass +class RPCProcessRequest: + prompt: PromptType + params: Union[SamplingParams, PoolingParams] + request_id: str + lora_request: Optional[LoRARequest] = None + trace_headers: Optional[Mapping[str, str]] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + priority: int = 0 + + def __init__( + self, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + super().__init__() + + self.prompt = prompt + self.params = params + self.request_id = request_id + self.lora_request = lora_request + self.trace_headers = trace_headers + self.prompt_adapter_request = prompt_adapter_request + self.priority = priority + + +@dataclass +class RPCError: + request_id: Optional[str] + is_engine_errored: bool + exception: BaseException + + +@dataclass +class RPCAbortRequest: + request_id: str + + +class RPCStartupRequest(Enum): + IS_SERVER_READY = 1 + + +@dataclass +class RPCStartupResponse: + tracing_enabled: bool + + +class RPCUProfileRequest(Enum): + START_PROFILE = 1 + STOP_PROFILE = 2 + + +class RPCResetMultiModalCacheRequest(Enum): + RESET = 1 + + +@dataclass +class RPCResetPrefixCacheRequest: + device: Device + + +class RPCSleepRequest(Enum): + SLEEP_LEVEL_1 = 1 + SLEEP_LEVEL_2 = 2 + + +@dataclass +class RPCWakeUpRequest: + tags: Optional[list[str]] = None + + +@dataclass +class RPCIsSleepingRequest: + # Set the default value of request_id to a new UUID + request_id: str = field(default_factory=lambda: str(uuid.uuid4())) + + +@dataclass +class RPCIsSleepingResponse: + request_id: str + is_sleeping: bool + + +@dataclass +class RPCLoadAdapterRequest: + lora_request: LoRARequest + # Set the default value of request_id to a new UUID + request_id: str = field(default_factory=lambda: str(uuid.uuid4())) + + +@dataclass +class RPCAdapterLoadedResponse: + request_id: str + + +RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest, + RPCUProfileRequest, RPCLoadAdapterRequest, + RPCResetMultiModalCacheRequest, + RPCResetPrefixCacheRequest, RPCSleepRequest, + RPCWakeUpRequest, RPCIsSleepingRequest] + +REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse, + RPCIsSleepingResponse, RPCError] + + +def ENGINE_DEAD_ERROR( + error: Optional[BaseException] = None) -> MQEngineDeadError: + if error is None: + return MQEngineDeadError( + "Engine loop is not running. Inspect the stacktrace to " + "find the original error") + + return MQEngineDeadError( + "Engine loop is not running. Inspect the stacktrace to " + f"find the original error: {repr(error)}.") diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py new file mode 100644 index 0000000..90e4671 --- /dev/null +++ b/vllm/engine/multiprocessing/client.py @@ -0,0 +1,686 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import copy +import pickle +from contextlib import contextmanager, suppress +from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, + Optional, Union, cast) + +import cloudpickle +import psutil +import zmq +import zmq.asyncio +from zmq import Frame # type: ignore[attr-defined] +from zmq.asyncio import Socket + +from vllm import PoolingParams +from vllm.config import DecodingConfig, ModelConfig, VllmConfig +from vllm.core.scheduler import SchedulerOutputs +# yapf conflicts with isort for this block +# yapf: disable +from vllm.engine.async_llm_engine import ( + build_guided_decoding_logits_processor_async) +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, RPC_REQUEST_T, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCAdapterLoadedResponse, RPCError, + RPCIsSleepingRequest, + RPCIsSleepingResponse, + RPCLoadAdapterRequest, + RPCProcessRequest, + RPCResetMultiModalCacheRequest, + RPCResetPrefixCacheRequest, + RPCSleepRequest, RPCStartupRequest, + RPCStartupResponse, + RPCUProfileRequest, RPCWakeUpRequest) +from vllm.engine.protocol import EngineClient +# yapf: enable +from vllm.envs import VLLM_RPC_TIMEOUT +from vllm.inputs import PromptType +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs + +from vllm.utils import Device +from vllm.transformers_utils.tokenizers import CPM9GTokenizer + +logger = init_logger(__name__) + + +class MQClientClosedError(Exception): + """Exception class raised when the client is used post-close. + + The client can be closed, which closes the ZMQ context. This normally + happens on server shutdown. In some cases, methods like abort and + do_log_stats will still be called and then try to open a socket, which + causes a ZMQError and creates a huge stack trace. + So, we throw this error such that we can suppress it. + """ + + +class MQLLMEngineClient(EngineClient): + """A client wrapper for MQLLMEngine that conforms to the + EngineClient protocol. + + MQLLMEngine and MQLLMEngineClient are intended to run in separate + processes communicating via zeromq ipc sockets. + + The entrypoint to MQLLMEngineClient is through the generate() + method. On generate() MQLLMEngine does three things: + - Creates an asyncio output queue + - Sends a RPCGenerateRequest to the MQLLMEngine via zmq + - Pulls RequestOutputs from its queue and yields them + + MQLLMEngine runs two background loops: + - output_loop: the output loop pulls List[RequestOutput] + from the MQLLMEngine via zmq (each list is the output + of one engine_step in the LLMEngine). It then parses + the list and pushes individual request_outputs into + the corresponding output_queue such that they can be + consumed by the .generate() method. + - health_loop: the health loop queries the health socket + every N seconds, confirming the engine is healthy + """ + + def __init__(self, ipc_path: str, engine_config: VllmConfig, + engine_pid: int): + self.context = zmq.asyncio.Context() + self._errored_with: Optional[BaseException] = None + + # Get the configs. + self.vllm_config = engine_config + self.model_config = engine_config.model_config + self.decoding_config = engine_config.decoding_config + + # Create the tokenizer group. + if self.model_config.tokenizer_mode != "cpm": + self.tokenizer = init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=engine_config.scheduler_config, + lora_config=engine_config.lora_config) + else: + self.tokenizer = CPM9GTokenizer(self.model_config.model, trust_remote_code=True) + self.input_preprocessor = InputPreprocessor(self.model_config, + self.tokenizer) + + # Send RPCGenerateRequest to the MQLLMEngine. + self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) + self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}") + + # Receive streams of RequestOutput from the MQLLMEngine. + self.output_socket: Socket = self.context.socket(zmq.constants.PULL) + self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # IPC path for acking heartbeats. + self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) + self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Stream for each individual request. + self.output_queues: Dict[str, asyncio.Queue] = {} + + # Loop to handle output of the LLMEngine periodically. + # Started after the MQLLMEngine is ready so that we can + # build the Client in an executor to enable clean shutdown. + self.output_loop: Optional[asyncio.Task] = None + + # Loop to check health of the LLMEngine periodically. + # Started after the MQLLMEngine is ready. + self.health_loop: Optional[asyncio.Task] = None + self._engine_process = psutil.Process(engine_pid) + + @staticmethod + def is_unsupported_config(vllm_config: VllmConfig): + # Pipeline parallel not yet supported + return vllm_config.parallel_config.pipeline_parallel_size > 1 + + @contextmanager + def get_data_socket(self) -> Iterator[Socket]: + socket = self.context.socket(zmq.constants.DEALER) + try: + socket.connect(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + async def run_heartbeat_loop(self, timeout: int): + """Background loop that continually checks to ensure the engine process + is still alive. + """ + try: + while True: + # Check if the engine process is running: + if not self._engine_process.is_running() or ( + self._engine_process.status() == psutil.STATUS_ZOMBIE): + # NB: is_running() returns True for zombies + self._set_errored( + RuntimeError( + f"Engine process (pid {self._engine_process.pid}) " + "died.")) + break + + if await self.heartbeat_socket.poll(timeout=timeout): + # Heartbeat received- check the message + await self._check_success( + error_message="Heartbeat failed.", + socket=self.heartbeat_socket) + + logger.debug("Heartbeat successful.") + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient check health loop.") + + except psutil.NoSuchProcess: + self._set_errored( + RuntimeError( + f"Engine process (pid {self._engine_process.pid}) died.")) + + except Exception as e: + self._set_errored(e) + + async def run_output_handler_loop(self): + """Get RequestOutputs from Engine and stream to Request Queues""" + + try: + while True: + # Poll, checking for ENGINE_DEAD + while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT + ) == 0: + logger.debug("Waiting for output from MQLLMEngine.") + + # If errored, alert all running requests. + if self.errored: + for queue_j in tuple(self.output_queues.values()): + queue_j.put_nowait( + ENGINE_DEAD_ERROR(self._errored_with)) + return + + message: Frame = await self.output_socket.recv(copy=False) + request_outputs = pickle.loads(message.buffer) + + is_error = isinstance(request_outputs, + (BaseException, RPCError)) + if is_error: + if isinstance(request_outputs, RPCError): + rpc_error: RPCError = request_outputs + request_id = rpc_error.request_id + exception = rpc_error.exception + is_engine_errored = rpc_error.is_engine_errored + else: + # MPLLMEngine should always return an RPCError to + # the output_socket when an issue arises. + # If we are here, we are in a bad state and + # should shut down the server. + error: BaseException = request_outputs + logger.error( + "Received Exception %s rather than RPCError from " + "MPLLMEngine. This should never happen.", error) + request_id = None + exception = error + is_engine_errored = True + + # Set to error state only on engine critical error + # (and record only the first one) + if is_engine_errored and not self._errored_with: + self._errored_with = exception + # If engine is errored, no matter the type of exception + # it will no longer be able to receive new requests, + # therefore we have to inform that the current + # processed requests failed as well. Send back a dead + # engine error give this feedback and also give a + # 'hint' to the server to shutdown next. + exception = self.dead_error + + if request_id is None: + # If request_id is None, then the engine raised an + # exception for a batch, and we may not know the + # request that caused it, neither if it was actually + # caused by any of them (e.g. CUDA OOM). Therefore we + # broadcast the same exception for all requests. + for queue_i in tuple(self.output_queues.values()): + queue_i.put_nowait(exception) + else: + queue = self.output_queues.get(request_id) + if queue is not None: + queue.put_nowait(exception) + # Put each output into the appropriate queue. + elif isinstance( + request_outputs, + (RPCAdapterLoadedResponse, RPCIsSleepingResponse)): + self._add_output(request_outputs) + else: + for request_output in request_outputs: + self._add_output(request_output) + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient output handler.") + + def _add_output(self, request_output: Union[RequestOutput, + RPCAdapterLoadedResponse, + RPCIsSleepingResponse]): + queue = self.output_queues.get(request_output.request_id) + if queue is not None: + queue.put_nowait(request_output) + + async def setup(self): + """Setup the client before it starts sending server requests.""" + + # Start output_loop + if self.output_loop is None: + # only generate once to avoid multiple concurrent output_loops + # this will lead to race conditions and wrong orders of tokens + # returned by the engine + # setup will be called multiple times during the startup of + # the engine + self.output_loop = asyncio.create_task( + self.run_output_handler_loop()) + + with self.get_data_socket() as socket: + # Wait until server is ready. + response = await self._wait_for_server_rpc(socket) + + self.tracing_flag = response.tracing_enabled + + # Start health_loop. + if self.health_loop is None: + self.health_loop = asyncio.create_task( + self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT)) + + def close(self): + """Destroy the ZeroMQ Context.""" + # Close all sockets and terminate the context. + self.context.destroy(linger=0) + + # Cancel background tasks. + if self.health_loop is not None: + self.health_loop.cancel() + if self.output_loop is not None: + self.output_loop.cancel() + + def _set_errored(self, e: BaseException): + logger.exception(repr(e)) + if self._errored_with is None: + self._errored_with = e + + @staticmethod + async def _send_get_data_rpc_request(request: RPCStartupRequest, + expected_type: Any, + error_message: str, + socket: Socket) -> Any: + """Send an RPC request that is expecting data back.""" + + # Ping RPCServer with a request. + await socket.send_multipart((pickle.dumps(request), ), copy=False) + + # Make sure the server responds in time. + if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + raise TimeoutError("RPCServer didn't reply within " + f"{VLLM_RPC_TIMEOUT} ms") + + # Await the data from the Server. + frame = await socket.recv(copy=False) + data = pickle.loads(frame.buffer) + + if isinstance(data, BaseException): + raise data + elif not isinstance(data, expected_type): + raise ValueError(error_message) + + return data + + @staticmethod + async def _send_one_way_rpc_request(request: RPC_REQUEST_T, + socket: Socket): + """Send one-way RPC request to trigger an action.""" + + if socket.closed: + raise MQClientClosedError() + + await socket.send_multipart((pickle.dumps(request), )) + + async def _await_ack(self, error_message: str, socket: Socket): + """Await acknowledgement that a request succeeded.""" + + if socket.closed: + raise MQClientClosedError() + + if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + raise TimeoutError("MQLLMEngine didn't reply within " + f"{VLLM_RPC_TIMEOUT}ms") + + await self._check_success(error_message, socket) + + @staticmethod + async def _check_success(error_message: str, socket: Socket): + """Confirm that socket has a VLLM_RPC_SUCCESS_STR message""" + + if socket.closed: + raise MQClientClosedError() + + frame = await socket.recv(copy=False) + response = pickle.loads(frame.buffer) + + # Raise error if unsuccessful + if isinstance(response, BaseException): + raise response + elif (not isinstance(response, str) + or response != VLLM_RPC_SUCCESS_STR): + raise ValueError(error_message) + + async def get_input_preprocessor(self) -> InputPreprocessor: + return self.input_preprocessor + + async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None): + return await self.tokenizer.get_lora_tokenizer_async(lora_request) if self.model_config.tokenizer_mode != "cpm" else self.tokenizer + + async def get_vllm_config(self) -> VllmConfig: + return self.vllm_config + + async def get_decoding_config(self) -> DecodingConfig: + return self.decoding_config + + async def get_model_config(self) -> ModelConfig: + return self.model_config + + async def is_tracing_enabled(self) -> bool: + return self.tracing_flag + + async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse: + """Wait for the RPCServer to start up.""" + + return await self._send_get_data_rpc_request( + request=RPCStartupRequest.IS_SERVER_READY, + expected_type=RPCStartupResponse, + error_message="Unable to start RPC Server", + socket=socket) + + async def abort(self, request_id: str): + """Send an ABORT_REQUEST signal to the RPC Server""" + + with suppress(MQClientClosedError): + await self._send_one_way_rpc_request( + request=RPCAbortRequest(request_id), socket=self.input_socket) + + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None, + ) -> None: + """ + Ignore do_log_stats (handled on MQLLMEngine polling) + """ + pass + + async def check_health(self): + """ + The check health loop probes the health status of the + Engine's health every N seconds and sets _errored_with + if the engine is unhealthy. + """ + if self._errored_with is not None: + raise self._errored_with + + @property + def is_running(self) -> bool: + return not self.errored + + @property + def is_stopped(self) -> bool: + return self.errored + + @property + def errored(self) -> bool: + return self._errored_with is not None + + @property + def dead_error(self) -> BaseException: + return ENGINE_DEAD_ERROR(self._errored_with) + + def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[RequestOutput, None]: + """Generate outputs for a request. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt to the LLM. See + [`PromptType`][vllm.inputs.PromptType] for more details about + the format of each input. + sampling_params: The sampling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: Prompt Adapter request to use + for generation, if any. + priority: Priority of the request (lower means earlier handling). + Any priority other than 0 will lead to an error if the + scheduling policy is not "priority". + """ + return cast( + AsyncGenerator[RequestOutput, None], + self._process_request(prompt, sampling_params, request_id, + lora_request, trace_headers, + prompt_adapter_request, priority)) + + def encode( + self, + prompt: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> AsyncGenerator[PoolingRequestOutput, None]: + """Generate outputs for a request from a pooling model. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt to the LLM. See + [`PromptType`][vllm.inputs.PromptType] for more details about + the format of each input. + pooling_params: The pooling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + + Yields: + The output `PoolingRequestOutput` objects from the LLMEngine + for the request. + """ + return cast( + AsyncGenerator[PoolingRequestOutput, None], + self._process_request(prompt, + pooling_params, + request_id, + lora_request, + trace_headers, + priority=priority)) + + async def _process_request( + self, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ + PoolingRequestOutput, None]]: + """Send an RPCGenerateRequest to the RPCServer and stream responses.""" + + # If already dead, error out. + if self._errored_with is not None: + raise ENGINE_DEAD_ERROR(self._errored_with) + + # Ensure the request id is unique among running requests + if request_id in self.output_queues: + raise ValueError(f"Request {request_id} already exists") + + # Constructing guided decoding logits processors is expensive, so we do + # it here to avoid contending with cpu resources and the GIL on the + # backend process. + if isinstance(params, SamplingParams) and \ + params.guided_decoding is not None: + params = await \ + build_guided_decoding_logits_processor_async( + sampling_params=params, + tokenizer=await self.get_tokenizer(lora_request), + default_guided_backend=(self.decoding_config.backend + if self.decoding_config + else DecodingConfig.backend), + model_config=self.model_config, + reasoning_backend=self.decoding_config.reasoning_backend, + ) + + # 1) Create output queue for this requests. + queue: asyncio.Queue[Union[RequestOutput, + BaseException]] = asyncio.Queue() + self.output_queues[request_id] = queue + + try: + # 2) Detach logits processors so that they can be pickled + # separately (may require cloudpickle which is slower) + if isinstance(params, SamplingParams) and params.logits_processors: + # Defensive shallow copy + params = copy.copy(params) + logits_processors = params.logits_processors + params.logits_processors = None + lp_bytes = cloudpickle.dumps(logits_processors) + else: + lp_bytes = None + + request_bytes = pickle.dumps( + RPCProcessRequest( + prompt=prompt, + params=params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + )) + + # 3) Send the RPCGenerateRequest to the MQLLMEngine. + parts = (request_bytes, + lp_bytes) if lp_bytes else (request_bytes, ) + await self.input_socket.send_multipart(parts, copy=False) + + # 4) Stream the RequestOutputs from the output queue. Note + # that the output_loop pushes RequestOutput objects to this + # queue after pulling them from the zmq socket. + finished = False + try: + while not finished: + request_output = await queue.get() + + if isinstance(request_output, BaseException): + raise request_output + + finished = request_output.finished + yield request_output + finally: + # Request was canceled by the client. + if not finished and not self.errored: + await self.abort(request_id) + finally: + self.output_queues.pop(request_id) + + async def start_profile(self) -> None: + """Start profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket) + + async def stop_profile(self) -> None: + """Stop profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket) + + async def reset_mm_cache(self) -> None: + """Reset the multi-modal cache""" + + await self._send_one_way_rpc_request( + request=RPCResetMultiModalCacheRequest.RESET, + socket=self.input_socket) + + async def reset_prefix_cache(self, + device: Optional[Device] = None) -> None: + """Reset the prefix cache""" + + await self._send_one_way_rpc_request( + request=RPCResetPrefixCacheRequest(device), + socket=self.input_socket) + + async def sleep(self, level: int = 1) -> None: + """Sleep the engine for a given level""" + return await self._send_one_way_rpc_request( + request=RPCSleepRequest(level), socket=self.input_socket) + + async def wake_up(self, tags: Optional[list[str]] = None) -> None: + """Wake up the engine""" + return await self._send_one_way_rpc_request( + request=RPCWakeUpRequest(tags), socket=self.input_socket) + + async def is_sleeping(self) -> bool: + """Check whether the engine is sleeping""" + request = RPCIsSleepingRequest() + + queue: asyncio.Queue[Union[BaseException, + RPCIsSleepingResponse]] = asyncio.Queue() + self.output_queues[request.request_id] = queue + + request_bytes = pickle.dumps(request) + await self.input_socket.send_multipart((request_bytes, ), copy=False) + + request_output = await queue.get() + self.output_queues.pop(request.request_id) + + if isinstance(request_output, BaseException): + raise request_output + return request_output.is_sleeping + + async def add_lora(self, lora_request: LoRARequest) -> None: + """Load a new LoRA adapter into the engine for future requests.""" + # Uses the same I/O as generate requests + request = RPCLoadAdapterRequest(lora_request) + + # Create output queue for this requests. + queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue() + self.output_queues[request.request_id] = queue + + # Send the request + request_bytes = pickle.dumps(request) + await self.input_socket.send_multipart((request_bytes, ), copy=False) + + # Wait for the response + request_output = await queue.get() + self.output_queues.pop(request.request_id) + + # Raise on error, otherwise happily return None + if isinstance(request_output, BaseException): + raise request_output diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py new file mode 100644 index 0000000..e04734f --- /dev/null +++ b/vllm/engine/multiprocessing/engine.py @@ -0,0 +1,478 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pickle +import signal +from contextlib import contextmanager +from typing import Iterator, List, Optional, Union + +import cloudpickle +import vllm.envs as envs +from vllm.zero_overhead.llm_engine import ZeroOverheadEngine +import zmq + +from vllm import AsyncEngineArgs, SamplingParams +from vllm.config import VllmConfig +from vllm.engine.llm_engine import LLMEngine +# yapf conflicts with isort for this block +# yapf: disable +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCAdapterLoadedResponse, RPCError, + RPCIsSleepingRequest, + RPCIsSleepingResponse, + RPCLoadAdapterRequest, + RPCProcessRequest, + RPCResetMultiModalCacheRequest, + RPCResetPrefixCacheRequest, + RPCSleepRequest, RPCStartupRequest, + RPCStartupResponse, + RPCUProfileRequest, RPCWakeUpRequest) +# yapf: enable +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.transformers_utils.config import ( + maybe_register_config_serialize_by_value) +from vllm.usage.usage_lib import UsageContext +from vllm.worker.model_runner_base import InputProcessingError +import time + +logger = init_logger(__name__) + +POLLING_TIMEOUT_MS = 10000 +HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) + + +class MQLLMEngine: + """A multiprocessing wrapper for + [`LLMEngine`][vllm.engine.llm_engine.LLMEngine]. + + This class is used to wrap the + [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] class to enable use + in concurrnet manner. It runs a background loop and uses zeromq to + receive new requests and stream outputs incrementally via ipc. + + The [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] generate or encode + process is kicked off when a new RPCProcessRequest is received by the + input_socket. + + The self.engine_loop checks the input_socket for new requests, + adds them to the LLMEngine if there are any, calls the internal + [`LLMEngine.step()`][vllm.engine.llm_engine.LLMEngine.step], and sends + the RequestOutputs back over the output_socket. + + If use_async_sockets is set, the logic associated with reading new + requests from the socket and sending data to the socket is passed + as a callback to the llm_engine, which calls the logic asynchronously + such that the IPC can be overlapped with the GPU. + + Args: + ipc_path: Base path for zeromq interprocess messaging + use_async_sockets: Whether to make send/recv async with GPU + log_requests: Whether to log the requests. + *args: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine]. + **kwargs: Arguments for [`LLMEngine`][vllm.engine.llm_engine.LLMEngine]. + """ + + def __init__(self, + ipc_path: str, + use_async_sockets: bool, + *args, + log_requests: bool = True, + **kwargs) -> None: + # For MQLLMEngine, we can use cached outputs, since each new request + # output is immediately pickled and send over the socket, which frees + # the python object to be reused again. + kwargs['use_cached_outputs'] = True + + if envs.VLLM_ZERO_OVERHEAD: + self.engine = ZeroOverheadEngine(*args, **kwargs) + else: + self.engine = LLMEngine(*args, **kwargs) + self.log_requests = log_requests + + self.use_async_sockets = use_async_sockets + if self.use_async_sockets: + self.engine.process_request_outputs_callback = \ + self._async_socket_engine_callback + + self.ctx = zmq.Context() # type: ignore[attr-defined] + + # Receive input from the client. + self.input_socket = self.ctx.socket(zmq.constants.PULL) + self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}") + + # Send output stream back to client. + self.output_socket = self.ctx.socket(zmq.constants.PUSH) + self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # Send heartbeats back to client. + self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) + self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Error state. + self._errored_with: Optional[BaseException] = None + + @property + def dead_error(self) -> BaseException: + if self._errored_with is not None: + return ENGINE_DEAD_ERROR(self._errored_with) + else: + return ENGINE_DEAD_ERROR() + + @classmethod + def from_vllm_config(cls, vllm_config: VllmConfig, + usage_context: UsageContext, + disable_log_requests: bool, disable_log_stats: bool, + ipc_path: str) -> "MQLLMEngine": + # Setup plugins for each process + from vllm.plugins import load_general_plugins + load_general_plugins() + + use_async_sockets = vllm_config.model_config.use_async_output_proc + + return cls( + vllm_config=vllm_config, + executor_class=LLMEngine._get_executor_cls(vllm_config), + ipc_path=ipc_path, + usage_context=usage_context, + use_async_sockets=use_async_sockets, + log_requests=(not disable_log_requests), + log_stats=(not disable_log_stats), + ) + + @staticmethod + def from_engine_args(engine_args: AsyncEngineArgs, + usage_context: UsageContext, ipc_path: str): + """Creates an MQLLMEngine from the engine arguments.""" + + vllm_config = engine_args.create_engine_config(usage_context) + return MQLLMEngine.from_vllm_config( + ipc_path=ipc_path, + vllm_config=vllm_config, + usage_context=usage_context, + disable_log_requests=engine_args.disable_log_requests, + disable_log_stats=engine_args.disable_log_stats, + ) + + def start(self): + try: + try: + logger.debug("Starting Startup Loop.") + self.run_startup_loop() + logger.debug("Starting Engine Loop.") + self.run_engine_loop() + except Exception as e: + logger.exception(repr(e)) + except KeyboardInterrupt: + logger.debug("Shutting down MQLLMEngine.") + finally: + logger.debug("MQLLMEngine is shut down.") + self.cleanup() + + def cleanup(self): + """Cleanup zeromq state on shutdown.""" + # Closes all sockets and destroys context. + self.ctx.destroy(linger=0) + del self.engine + + @contextmanager + def make_data_socket( + self) -> Iterator[zmq.Socket]: # type: ignore[name-defined] + socket = self.ctx.socket(zmq.constants.ROUTER) + try: + socket.bind(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + def run_startup_loop(self) -> None: + """Startup loop for sending data from Engine -> Client.""" + + with self.make_data_socket() as socket: + response: Union[RPCStartupResponse, BaseException] + try: + identity, message = socket.recv_multipart(copy=False) + request: RPCStartupRequest = pickle.loads(message.buffer) + + # Handle the query from the Client. + if request == RPCStartupRequest.IS_SERVER_READY: + tracing_enabled = self.engine.is_tracing_enabled() + response = RPCStartupResponse( + tracing_enabled=tracing_enabled) + + except Exception as e: + response = e + + socket.send_multipart((identity, pickle.dumps(response)), + copy=False) + + def run_engine_loop(self): + """Core busy loop of the LLMEngine.""" + + last_no_req_time_refreshed = True + last_no_req_time = time.perf_counter() + while True: + if not self.engine.has_unfinished_requests(): + # Poll until there is work to do. + while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + # When there's no work, check on engine health and send + # health status back to client + self._health_check() + self.engine.do_log_stats() + logger.debug("Waiting for new requests in engine loop.") + last_no_req_time = time.perf_counter() + last_no_req_time_refreshed = True + + # Handle any input from the client. + self.handle_new_input() + + if envs.VLLM_TBO_REQ_DELAY_MS > 0 and last_no_req_time_refreshed and envs.VLLM_ENABLE_TBO: + if self.engine.get_num_unfinished_requests() < 2: + time_diff_ms = int((time.perf_counter() - last_no_req_time) * 1000) + if time_diff_ms < envs.VLLM_TBO_REQ_DELAY_MS: + time.sleep(0.01) # sleep and waiting more request to merge in one batch + continue + + last_no_req_time_refreshed = False + # Engine step. + request_outputs = self.engine_step() + + # Send request outputs (if async, done in engine_step callback). + if not self.use_async_sockets: + self._send_outputs(request_outputs) + + def engine_step(self) -> List[RequestOutput]: + """Engine step wrapper with error handling.""" + try: + return self.engine.step() + except SystemExit: + raise + except InputProcessingError as e: + # Special case where we handle an error preparing the inputs for + # a single request in the batch + rpc_err = RPCError(request_id=e.request_id, + is_engine_errored=False, + exception=e.__cause__) + self._send_outputs(rpc_err) + return [] + except BaseException as e: + self._set_errored(e) + rpc_err = RPCError(request_id=None, + is_engine_errored=True, + exception=e) + self._send_outputs(rpc_err) + raise e + + def handle_new_input(self): + """Handle new input from the socket""" + try: + while self.input_socket.poll(timeout=0) != 0: + frames = self.input_socket.recv_multipart(copy=False) + request = pickle.loads(frames[0].buffer) + + if isinstance(request, RPCProcessRequest): + if len(frames) > 1: + # Use cloudpickle for logits processors + assert isinstance(request.params, SamplingParams) + lprocs = cloudpickle.loads(frames[1].buffer) + request.params.logits_processors = lprocs + self._handle_process_request(request) + elif isinstance(request, RPCAbortRequest): + self._handle_abort_request(request) + elif isinstance(request, RPCUProfileRequest): + if request == RPCUProfileRequest.START_PROFILE: + self.start_profile() + else: + self.stop_profile() + elif isinstance(request, RPCLoadAdapterRequest): + self._handle_load_adapter_request(request) + elif isinstance(request, RPCResetMultiModalCacheRequest): + self.reset_mm_cache() + elif isinstance(request, RPCResetPrefixCacheRequest): + self.reset_prefix_cache() + elif isinstance(request, RPCSleepRequest): + self.sleep(request.value) + elif isinstance(request, RPCWakeUpRequest): + self.wake_up(request.tags) + elif isinstance(request, RPCIsSleepingRequest): + self._handle_is_sleeping_request(request) + else: + raise ValueError("Unknown RPCRequest Type: " + f"{type(request)}") + + except Exception as e: + self._set_errored(e) + self._send_unhealthy(e) + raise e from None + + def _handle_process_request(self, request: RPCProcessRequest): + """Handle RPCProcessRequest by adding it to the LLMEngine.""" + request_id = request.request_id + + if self._errored_with is not None: + rpc_err = RPCError(request_id=request_id, + is_engine_errored=True, + exception=ENGINE_DEAD_ERROR(self._errored_with)) + self._send_outputs(rpc_err) + + try: + self.engine.add_request( + request_id=request_id, + prompt=request.prompt, + params=request.params, + lora_request=request.lora_request, + trace_headers=request.trace_headers, + prompt_adapter_request=request.prompt_adapter_request, + priority=request.priority) + + if self.log_requests: + logger.info("Added request %s.", request.request_id) + + except Exception as e: + # We do not set self._errored = True here, since the error + # is due to an issue adding this request to the engine, + # rather than an issue with the engine itself. + logger.debug("Failed to add request %s to engine. %s", + request.request_id, e) + is_errored = self._errored_with is not None + rpc_err = RPCError(request_id=request_id, + is_engine_errored=is_errored, + exception=e) + self._send_outputs(rpc_err) + + # Remove request from the engine. + self.engine.abort_request(request_id) + + def _handle_abort_request(self, request: RPCAbortRequest): + self.engine.abort_request(request.request_id) + if self.log_requests: + logger.info("Aborted request %s.", request.request_id) + + def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest): + try: + self.engine.add_lora(request.lora_request) + except BaseException as e: + # Send back an error if the adater fails to load + rpc_err = RPCError(request_id=request.request_id, + is_engine_errored=False, + exception=e) + self._send_outputs(rpc_err) + return + # Otherwise, send back the successful load message + self._send_outputs( + RPCAdapterLoadedResponse(request_id=request.request_id)) + + def _handle_is_sleeping_request(self, request: RPCIsSleepingRequest): + is_sleeping = self.is_sleeping() + self._send_outputs( + RPCIsSleepingResponse(request_id=request.request_id, + is_sleeping=is_sleeping)) + + def _health_check(self): + # Send unhealthy if engine has already errored + if self._errored_with is not None: + self._send_unhealthy(self._errored_with) + try: + self.engine.check_health() + self._send_healthy() + except Exception as e: + self._set_errored(e) + self._send_unhealthy(e) + + def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): + """Send outputs back to the engine client. These can be: + - Exceptions + - A list of generation outputs + - A response from loading a lora adapter + """ + if outputs: + try: + from ray.exceptions import RayTaskError + + # RayTaskError might not pickelable here. We need to unpack the + # underlying exception as the real exception in the output. + if (isinstance(outputs, RPCError) + and isinstance(outputs.exception, RayTaskError)): + outputs.exception = outputs.exception.cause + except ImportError: + pass + + output_bytes = pickle.dumps(outputs) + self.output_socket.send_multipart((output_bytes, ), copy=False) + + def _send_healthy(self): + """Send HEALTHY message to RPCClient.""" + if not self.heartbeat_socket.closed: + self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False) + + def _send_unhealthy(self, error: BaseException): + """Send UNHEALTHY message to RPCClient.""" + if not self.heartbeat_socket.closed: + error_bytes = pickle.dumps(error) + self.heartbeat_socket.send_multipart((error_bytes, ), copy=False) + + def _async_socket_engine_callback(self, + request_outputs: REQUEST_OUTPUTS_T): + """Callback used by engine to make socket handling async with GPU.""" + self._send_outputs(request_outputs) + self.handle_new_input() + + def _set_errored(self, e: BaseException): + """Log and set errored status if this is the first issue.""" + if self._errored_with is None: + self._errored_with = e + + def start_profile(self) -> None: + self.engine.start_profile() + + def stop_profile(self) -> None: + self.engine.stop_profile() + + def reset_mm_cache(self) -> bool: + return self.engine.reset_mm_cache() + + def reset_prefix_cache(self) -> bool: + return self.engine.reset_prefix_cache() + + def sleep(self, level: int = 1) -> None: + self.engine.sleep(level) + + def wake_up(self, tags: Optional[list[str]] = None) -> None: + self.engine.wake_up(tags) + + def is_sleeping(self) -> bool: + return self.engine.is_sleeping() + + +def signal_handler(*_) -> None: + raise KeyboardInterrupt("MQLLMEngine terminated") + + +def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext, + ipc_path: str, disable_log_stats: bool, + disable_log_requests: bool, engine_alive): + try: + # Ensure we can serialize transformer config before spawning + maybe_register_config_serialize_by_value() + + engine = MQLLMEngine.from_vllm_config( + vllm_config=vllm_config, + usage_context=usage_context, + disable_log_stats=disable_log_stats, + disable_log_requests=disable_log_requests, + ipc_path=ipc_path) + + signal.signal(signal.SIGTERM, signal_handler) + + engine.start() + + except BaseException as e: + logger.exception(e) + engine_alive.value = False + raise e from None diff --git a/vllm/engine/output_processor/__init__.py b/vllm/engine/output_processor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py new file mode 100644 index 0000000..19c5963 --- /dev/null +++ b/vllm/engine/output_processor/interfaces.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from typing import Callable, List + +from vllm.config import SchedulerConfig +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import Counter + + +class SequenceGroupOutputProcessor(ABC): + """Interface for logic that processes new token ids in sequence groups, + managing detokenization, stop checking, and freeing/forking sequences with + the scheduler. + + This is highly coupled with the LLMEngine and should be seen as an extension + of it. The logic is separated to simplify the LLMEngine class and allow + separate implementations for single-step decoding (which supports beam + search sequence forking) and multi-step decoding (which does not support + beam search, but does support speculative decoding). + """ + + @staticmethod + def create_output_processor( + scheduler_config: SchedulerConfig, + detokenizer: Detokenizer, + scheduler: List[Scheduler], + seq_counter: Counter, + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer], + stop_checker: "StopChecker", + ): + """Create an output processor. + + This returns a single-step output processor if num_lookahead_slots is + zero, else returns a multi-step output processor. + """ + if scheduler_config.num_lookahead_slots == 0: + # Importing here to avoid cycle. + from vllm.engine.output_processor.single_step import ( + SingleStepOutputProcessor) + return SingleStepOutputProcessor(scheduler_config, detokenizer, + scheduler, seq_counter, + stop_checker) + else: + # Importing here to avoid cycle. + from vllm.engine.output_processor.multi_step import ( + MultiStepOutputProcessor) + return MultiStepOutputProcessor( + detokenizer, + scheduler, + seq_counter, + get_tokenizer_for_seq, + stop_checker, + ) + + @abstractmethod + def process_outputs(self, sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput], + is_async: bool) -> None: + """Process new token ids for the sequence group. Handles logic such as + detokenization, stop checking, and freeing/forking sequences in the + scheduler. + """ + pass + + @abstractmethod + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Update prompt logprobs received from outputs to seq_group.""" + pass diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py new file mode 100644 index 0000000..e0fa6a0 --- /dev/null +++ b/vllm/engine/output_processor/multi_step.py @@ -0,0 +1,216 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import functools +from typing import Callable, List, cast + +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.single_step import ( + single_step_process_prompt_logprob) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, + CompletionSequenceGroupOutput, Sequence, + SequenceGroup, SequenceGroupOutput, SequenceOutput, + SequenceStatus) +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import Counter + +logger = init_logger(__name__) + + +class MultiStepOutputProcessor(SequenceGroupOutputProcessor): + """SequenceGroupOutputProcessor which handles logic related to + detokenization and stopping conditions. It specializes to "multi-step + decoding", where vLLM's worker may generate multiple tokens per invocation. + This is currently mutually exclusive with advanced sampling techniques like + beam search, which motivates the separation of this logic from the single + step output processor. + + This class is responsible for things such as correctly appending all new + token ids to their sequence, detokenizing new token ids, truncating new + output tokens after an eos token, and correctly handling the case where the + number of new output tokens per sequence differs in a single batch. + """ + + def __init__( + self, + detokenizer: Detokenizer, + scheduler: List[Scheduler], + seq_counter: Counter, + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer], + stop_checker: StopChecker, + ): + self.detokenizer = detokenizer + self.scheduler = scheduler + self.seq_counter = seq_counter + self.get_tokenizer_for_seq = get_tokenizer_for_seq + self.stop_checker = stop_checker + + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Process prompt logprobs associated with each step of a multi-step- + scheduled computation. + + Args: + seq_group: the outputs are associated with this + [`SequenceGroup`][vllm.sequence.SequenceGroup] + outputs: the + [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]s + for all scheduler steps + """ + for output in outputs: + # Concatenate single-step prompt logprob processing results. + assert isinstance(output, CompletionSequenceGroupOutput) + single_step_process_prompt_logprob(self, seq_group, output) + + @staticmethod + @functools.lru_cache + def _log_prompt_logprob_unsupported_warning_once(): + # Reminder: Please update docs/features/compatibility_matrix.md + # If the feature combo become valid + logger.warning( + "Prompt logprob is not supported by multi step workers. " + "(e.g., speculative decode uses multi step workers).") + + def process_outputs(self, + sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput], + is_async: bool = False) -> None: + """Append new tokens in the outputs to sequences in the sequence group. + + This only supports sequence groups of size 1. It supports greater than + one new token per sequence. + + This applies logic like stop condition checking and detokenization. + It also handles cases where there are tokens emitted after + the EOS token. + + is_async - Indicates whether this postprocessor runs in + parallel with the GPU forward pass and is processing + tokens from the previous step. If this is true, then + no tokens need to be appended since it is already done + externally (before the next schedule() call) + """ + # Sequences can be in RUNNING or FINISHED_ABORTED state + # once scheduled, as a sequence is moved to FINISHED_ABORTED + # if a client disconnects from the api server. + seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING) + if seqs is None: + seqs = sequence_group.get_seqs( + status=SequenceStatus.FINISHED_ABORTED) + + for output in outputs: + if output.samples[0].output_token != VLLM_INVALID_TOKEN_ID: + sequence_group.metrics.spec_token_acceptance_counts[ + output.step_index] += 1 + + assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences" + assert len(seqs) == 1, ( + "Beam search not supported in multi-step decoding.") + seq = seqs[0] + seq_id = seq.seq_id + # This method is defined in the more generic + # SequenceGroupOutputProcessor, but here we assume that the outputs are + # of a more specific type. + assert all([ + isinstance(output, CompletionSequenceGroupOutput) + for output in outputs + ]) + compl_outputs = cast(List[CompletionSequenceGroupOutput], outputs) + assert all([ + seq_id == output.samples[0].parent_seq_id + for output in compl_outputs + ]) + + if is_async: + # Async case: We process tokens one by one. Here, we know the token + # was already appended, so we only need to do the rest of the + # postprocessor: Detokenization + stopping logic + self._process_decode_and_stop(seq, sequence_group.sampling_params) + else: + # Standard multi-step case + + # Since there's only one sequence per sequence group, + # we can take the first sample. + samples = [output.samples[0] for output in compl_outputs] + + # entries in sample tokens may be invalid (eg. due to spec decode + # rejecting tokens). + valid_samples = [ + sample for sample in samples + if sample.output_token != VLLM_INVALID_TOKEN_ID + ] + + # When both spec-decode and pre-fill chunking are enabled, we + # don't have guaranteed samples here (e.g. all -1s). + if valid_samples: + self._process_seq_outputs(seq, valid_samples, + sequence_group.sampling_params) + + def _process_decode_and_stop(self, seq: Sequence, + sampling_params: SamplingParams) -> None: + new_char_count = 0 + if sampling_params.detokenize and self.detokenizer: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, sampling_params) + + # TODO(sang): Support lora. + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count=new_char_count, + sampling_params=sampling_params, + ) + + def _process_seq_outputs(self, seq: Sequence, + valid_samples: List[SequenceOutput], + sampling_params: SamplingParams) -> None: + output_token_ids = [sample.output_token for sample in valid_samples] + output_logprobs = [sample.logprobs for sample in valid_samples] + output_embeds = [sample.output_embed for sample in valid_samples] + + # Truncate to max_tokens if necessary. + remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + + len(output_token_ids)) + if remaining_tokens < 0: + output_token_ids = output_token_ids[:remaining_tokens] + + # Truncate any tokens after EOS. This is required as spec decode + # generates a fixed number of tokens without evaluating stopping + # conditions within the block. This can cause an eos token to be + # unintentionally ignored. + if not sampling_params.ignore_eos and self.detokenizer: + eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id + # Avoiding .index calls as exception throwing in the happy path + # is expensive. + for i in range(len(output_token_ids)): + if output_token_ids[i] == eos_token_id: + output_token_ids = output_token_ids[:i + 1] + break + + is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0 + # Incrementally append tokens to the sequence, as if we had only one new + # token. + for output_token_id, output_logprob, output_embed in zip( + output_token_ids, output_logprobs, output_embeds): + seq.append_token_id( + token_id=output_token_id, + logprobs=output_logprob, + token_embed=output_embed, + ) + + if is_prefill_sampled_token: + is_prefill_sampled_token = False + else: + # Update num_computed_tokens iff the sampled token is not from + # a prefill step. + seq.data.update_num_computed_tokens(1) + + self._process_decode_and_stop(seq, sampling_params) + + if seq.is_finished(): + break diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py new file mode 100644 index 0000000..dbf6a37 --- /dev/null +++ b/vllm/engine/output_processor/single_step.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import List + +from vllm.config import SchedulerConfig +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.logger import init_logger +from vllm.sequence import (CompletionSequenceGroupOutput, SequenceGroup, + SequenceGroupOutput) +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.utils import Counter + +logger = init_logger(__name__) + + +def single_step_process_prompt_logprob( + sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup, + output: CompletionSequenceGroupOutput) -> None: + """Process prompt logprobs associated with the + [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput] for a given step. + + Do nothing if the output has no prompt logprobs. + + Account for the fact that transformers do not compute first-token logprobs. + + Args: + sg_output_proc: + [`SequenceGroupOutputProcessor`][vllm.engine.output_processor.interfaces.SequenceGroupOutputProcessor] + instance + seq_group: the output is associated with this + [`SequenceGroup`][vllm.sequence.SequenceGroup] + output: the [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput] + for a single scheduler step + """ + prompt_logprobs = output.prompt_logprobs + + # If this is the first (or only) "chunk" of the prefill, we need + # to prepend None to the list of prompt logprobs. The reason for this + # is that for N prompt tokens, the Sampler will generate N-1 total + # prompt logprobs during prefill since the token at idx 0 will not + # have a logprob associated with it. + if prompt_logprobs is not None: + if not seq_group.prompt_logprobs: + prompt_logprobs = [None] + prompt_logprobs + seq_group.prompt_logprobs = [] + + assert hasattr(sg_output_proc, 'detokenizer') + if (seq_group.sampling_params.detokenize + and sg_output_proc.detokenizer): + sg_output_proc.detokenizer.decode_prompt_logprobs_inplace( + seq_group, + prompt_logprobs, + position_offset=len(seq_group.prompt_logprobs)) + + seq_group.prompt_logprobs.extend(prompt_logprobs) + + +class SingleStepOutputProcessor(SequenceGroupOutputProcessor): + """SequenceGroupOutputProcessor which handles "output processing" logic, + which happens after the model returns generated token ids and before + scheduling of the next batch. Output processing logic includes + detokenization, and determining if a sequence is finished (e.g. via max len + or eos token). + + The SingleStepOutputProcessor is specialized to the case where the model + emits at most a single token per invocation, which precludes configurations + such as speculative decoding or multi-step decoding. This enables beam + search sampling, which requires forking/finishing/freeing sequences in a way + that is currently difficult to schedule multiple steps ahead of time. + """ + + def __init__(self, scheduler_config: SchedulerConfig, + detokenizer: Detokenizer, scheduler: List[Scheduler], + seq_counter: Counter, stop_checker: StopChecker): + self.scheduler_config = scheduler_config + self.detokenizer = detokenizer + self.scheduler = scheduler + self.seq_counter = seq_counter + self.stop_checker = stop_checker + + def process_outputs(self, sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput], + is_async: bool) -> None: + """Append all new tokens to sequences in the sequence group. Fork any + surviving beam candidates; free any unsurviving ones. + + Invokes detokenizer to detokenize new tokens, and also marks sequences + as finished if they meet stop conditions. + + is_async - Indicates whether this postprocessor runs in + parallel with the GPU forward pass and is processing + tokens from the previous step. If this is true, then + no tokens need to be appended since it is already done + externally (before the next schedule() call) + """ + assert (len(outputs) == 1 + ), f"{type(self)} does not support multiple outputs per step" + return self._process_sequence_group_outputs(sequence_group, outputs[0], + is_async) + + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Process prompt logprobs associated with one step of a single-step- + scheduled computation. + + Args: + seq_group: the output is associated with this + [`SequenceGroup`][vllm.sequence.SequenceGroup] + outputs: the + [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput] + for a single scheduler step + """ + assert len(outputs) == 1, "Single step should only have 1 output." + output = outputs[0] + assert isinstance(output, CompletionSequenceGroupOutput) + single_step_process_prompt_logprob(self, seq_group, output) + + def _process_sequence_group_outputs(self, seq_group: SequenceGroup, + outputs: SequenceGroupOutput, + is_async: bool) -> None: + sampling_params = seq_group.sampling_params + + sample = outputs.samples[0] + seq = seq_group.first_seq + if not is_async: + seq.append_token_id(sample.output_token, sample.logprobs, + sample.output_embed) + if sampling_params.detokenize and self.detokenizer: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, sampling_params) + else: + new_char_count = 0 + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count, + sampling_params, + lora_req=seq_group.lora_request, + ) + if seq.is_finished(): + for scheduler in self.scheduler: + scheduler.free_seq(seq) diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py new file mode 100644 index 0000000..3fb2f71 --- /dev/null +++ b/vllm/engine/output_processor/stop_checker.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Callable, List, Optional, Tuple + +from vllm.lora.request import LoRARequest +from vllm.sampling_params import SamplingParams +from vllm.sequence import Sequence, SequenceStatus +from vllm.transformers_utils.tokenizer import AnyTokenizer + + +class StopChecker: + """LLMEngine helper class which separates out the logic involving stop + checking. This checks things such as: whether the eos token was emitted, + whether the max_tokens has been consumed, whether a stop string has been + emitted, or if we have exceeded the max model len. + """ + + def __init__(self, max_model_len: int, + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]): + # Do not use it directly, but use `self._get_max_model_len`. + self._max_model_len = max_model_len + self.get_tokenizer_for_seq = get_tokenizer_for_seq + + def _get_max_model_len(self, lora_req: Optional[LoRARequest]): + if lora_req and lora_req.long_lora_max_len: + return lora_req.long_lora_max_len + else: + return self._max_model_len + + def maybe_stop_sequence( + self, + seq: Sequence, + new_char_count: int, + sampling_params: SamplingParams, + lora_req: Optional[LoRARequest] = None, + ) -> None: + """Stop the finished sequences. + + new_char_count is the number of chars added to the + sequence's output text for the newly generated token + """ + + # Check if the minimum number of tokens has been generated yet; + # skip the stop string/token checks if not + if seq.get_output_len() < sampling_params.min_tokens: + return + + # Check if the sequence has generated the EOS token. + if ((not sampling_params.ignore_eos) + and seq.get_last_token_id() == seq.eos_token_id): + # Remove the last EOS token unless explicitly specified + # This prevents unintended exposure of the EOS token + if new_char_count and ( + not sampling_params.include_stop_str_in_output): + seq.output_text = seq.output_text[:-new_char_count] + seq.status = SequenceStatus.FINISHED_STOPPED + return + + # Check if a stop token was encountered. + # This assumes a single token produced per step. + last_token_id = seq.get_last_token_id() + if last_token_id in (sampling_params.stop_token_ids or ()): + if new_char_count and ( + not sampling_params.include_stop_str_in_output): + # Remove last token + seq.output_text = seq.output_text[:-new_char_count] + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = last_token_id + return + + # Check if any stop strings are matched. + stop = self.check_stop_strings( + seq.output_text, new_char_count, sampling_params.stop, + sampling_params.include_stop_str_in_output) + if stop is not None: + stop_str, truncate_to = stop + if truncate_to != -1: + seq.output_text = seq.output_text[:truncate_to] + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = stop_str + return + + # Check if the sequence has reached max_model_len. + if seq.get_len() >= self._get_max_model_len(lora_req): + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the sequence has reached max_tokens. + if seq.get_output_len() == sampling_params.max_tokens: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + @staticmethod + def check_stop_strings( + output_text: str, + new_char_count: int, + stop: List[str], + include_in_output: bool, + ) -> Optional[Tuple[str, int]]: + """Check if any stop strings are matched and truncate sequence + output text accordingly. + + Returns tuple (stop_string, offset) if matched or else None. + + Where stop_string is the matched stop string and offset is the + length to which output_text should be truncated, or -1 for no + truncation. + """ + if not new_char_count or not stop: + return None + + for stop_str in stop: + stop_string_len = len(stop_str) + # Avoid searching already-searched text. + stop_index = output_text.find(stop_str, + 1 - new_char_count - stop_string_len) + if stop_index == -1: + continue + + if include_in_output: + # Truncate to end of stop string. + stop_index += stop_string_len + if stop_index >= len(output_text): + # No truncation required. + return stop_str, -1 + + # Truncate the output text to either the beginning + # or end of the stop string. + return stop_str, stop_index + return None diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py new file mode 100644 index 0000000..1e127eb --- /dev/null +++ b/vllm/engine/output_processor/util.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import List +from typing import Sequence as GenericSequence +from typing import cast + +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import CompletionSequenceGroupOutput, SequenceGroupOutput + + +def create_output_by_sequence_group( + outputs: GenericSequence[SamplerOutput], + num_seq_groups: int) -> List[List[SequenceGroupOutput]]: + """Helper method which transforms a 2d list organized by + [step][sequence group] into [sequence group][step]. + """ + output_by_sequence_group: List[List[CompletionSequenceGroupOutput]] = [ + [] for _ in range(num_seq_groups) + ] + for step in outputs: + sequence_group_output: CompletionSequenceGroupOutput + for i, sequence_group_output in enumerate(step): + output_by_sequence_group[i].append(sequence_group_output) + + # Cast to the more generic type that CompletionSequenceGroupOutput + # inherits from. + return cast(List[List[SequenceGroupOutput]], output_by_sequence_group) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py new file mode 100644 index 0000000..8688fcc --- /dev/null +++ b/vllm/engine/protocol.py @@ -0,0 +1,326 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +from abc import ABC, abstractmethod +from typing import AsyncGenerator, Mapping, Optional + +from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function +from vllm.config import DecodingConfig, ModelConfig, VllmConfig +from vllm.core.scheduler import SchedulerOutputs +from vllm.inputs.data import PromptType, TokensPrompt +from vllm.inputs.parse import is_explicit_encoder_decoder_prompt +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import Device, collect_from_async_generator, random_uuid + +logger = init_logger(__name__) + + +class EngineClient(ABC): + """Protocol class for Clients to Engine""" + + @property + @abstractmethod + def is_running(self) -> bool: + ... + + @property + @abstractmethod + def is_stopped(self) -> bool: + ... + + @property + @abstractmethod + def errored(self) -> bool: + ... + + @property + @abstractmethod + def dead_error(self) -> BaseException: + ... + + @abstractmethod + def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[RequestOutput, None]: + """Generate outputs for a request.""" + ... + + async def beam_search( + self, + prompt: PromptType, + request_id: str, + params: BeamSearchParams, + lora_request: Optional[LoRARequest] = None, + ) -> AsyncGenerator[RequestOutput, None]: + + beam_width = params.beam_width + max_tokens = params.max_tokens + ignore_eos = params.ignore_eos + temperature = params.temperature + length_penalty = params.length_penalty + include_stop_str_in_output = params.include_stop_str_in_output + + preprocessor = await self.get_input_preprocessor() + tokenizer_group = preprocessor.get_tokenizer_group() + tokenizer = await tokenizer_group.get_lora_tokenizer_async() + + if is_explicit_encoder_decoder_prompt(prompt): + raise NotImplementedError + else: + processed_inputs = preprocessor._prompt_to_llm_inputs(prompt) + + if processed_inputs["type"] == "embeds": + raise NotImplementedError + + # This is a workaround to fix multimodal beam search; this is a + # bandaid fix for 2 small problems: + # 1. Multi_modal_data on the processed_inputs currently resolves to + # `None`. + # 2. preprocessing above expands the multimodal placeholders. However, + # this happens again in generation, so the double expansion causes + # a mismatch. + # TODO - would be ideal to handle this more gracefully. + prompt_token_ids = prompt.get("prompt_token_ids") + multi_modal_data = prompt.get("multi_modal_data") + + prompt_text = processed_inputs.get("prompt") + mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs") + + tokenized_length = len(prompt_token_ids) + + sort_beams_key = create_sort_beams_key_function( + tokenizer.eos_token_id, length_penalty) + + beam_search_params = SamplingParams( + logprobs=2 * beam_width, + max_tokens=1, + temperature=temperature, + ) + all_beams = [ + BeamSearchSequence(tokens=prompt_token_ids, + cum_logprob=0, + logprobs=[], + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + lora_request=lora_request) + ] + completed = [] + + for _ in range(max_tokens): + prompts_batch, lora_req_batch = zip(*[( + TokensPrompt(prompt_token_ids=beam.tokens, + multi_modal_data=beam.multi_modal_data, + mm_processor_kwargs=beam.mm_processor_kwargs), + beam.lora_request, + ) for beam in all_beams]) + + tasks = [] + + request_id = f"beam_search-{random_uuid()}" + for i, (individual_prompt, + lora_req) in enumerate(zip(prompts_batch, lora_req_batch)): + request_id_item = f"{request_id}-{i}" + task = asyncio.create_task( + collect_from_async_generator( + self.generate(individual_prompt, + beam_search_params, + request_id_item, + lora_request=lora_req))) + tasks.append(task) + + output = await asyncio.gather(*tasks) + + output = [x[0] for x in output] + + new_beams = [] + for i, current_beam in enumerate(all_beams): + result = output[i] + + if result.outputs[0].logprobs is not None: + logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + if token_id == tokenizer.eos_token_id and \ + not ignore_eos: + completed.append( + BeamSearchSequence( + tokens=current_beam.tokens + + [token_id] if include_stop_str_in_output + else current_beam.tokens, + logprobs=current_beam.logprobs + + [logprobs], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + finish_reason="stop", + stop_reason=tokenizer.eos_token_id)) + else: + new_beams.append( + BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + logprobs=current_beam.logprobs + + [logprobs], + lora_request=current_beam.lora_request, + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + multi_modal_data=current_beam. + multi_modal_data, + mm_processor_kwargs=current_beam. + mm_processor_kwargs)) + + sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) + all_beams = sorted_beams[:beam_width] + + completed.extend(all_beams) + sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) + best_beams = sorted_completed[:beam_width] + + for beam in best_beams: + if (beam.tokens[-1] == tokenizer.eos_token_id and not ignore_eos): + # Skip the eos token in the text. + tokens = beam.tokens[tokenized_length:-1] + else: + tokens = beam.tokens[tokenized_length:] + beam.text = tokenizer.decode(tokens) + + beam_search_output = RequestOutput( + request_id=request_id, + prompt=prompt_text, + outputs=[ + CompletionOutput(text=beam.text, + cumulative_logprob=beam.cum_logprob, + token_ids=beam.tokens[tokenized_length:], + index=i, + logprobs=beam.logprobs, + finish_reason=beam.finish_reason if + beam.finish_reason is not None else "length", + stop_reason=beam.stop_reason) + for (i, beam) in enumerate(best_beams) + ], + finished=True, + prompt_token_ids=prompt_token_ids, + prompt_logprobs=None) + + yield beam_search_output + + @abstractmethod + def encode( + self, + prompt: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> AsyncGenerator[PoolingRequestOutput, None]: + """Generate outputs for a request from a pooling model.""" + ... + + @abstractmethod + async def abort(self, request_id: str) -> None: + """Abort a request. + + Args: + request_id: The unique id of the request. + """ + ... + + @abstractmethod + async def get_vllm_config(self) -> VllmConfig: + """Get the vllm configuration of the vLLM engine.""" + ... + + @abstractmethod + async def get_model_config(self) -> ModelConfig: + """Get the model configuration of the vLLM engine.""" + ... + + @abstractmethod + async def get_decoding_config(self) -> DecodingConfig: + """Get the decoding configuration of the vLLM engine.""" + ... + + @abstractmethod + async def get_input_preprocessor(self) -> InputPreprocessor: + """Get the input processor of the vLLM engine.""" + ... + + @abstractmethod + async def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + """Get the appropriate tokenizer for the request""" + ... + + @abstractmethod + async def is_tracing_enabled(self) -> bool: + ... + + @abstractmethod + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[list[SamplerOutput]] = None, + ) -> None: + ... + + @abstractmethod + async def check_health(self) -> None: + """Raise if unhealthy""" + ... + + @abstractmethod + async def start_profile(self) -> None: + """Start profiling the engine""" + ... + + @abstractmethod + async def stop_profile(self) -> None: + """Start profiling the engine""" + ... + + @abstractmethod + async def reset_mm_cache(self) -> None: + """Reset the multi-modal cache""" + ... + + @abstractmethod + async def reset_prefix_cache(self, + device: Optional[Device] = None) -> None: + """Reset the prefix cache""" + ... + + @abstractmethod + async def sleep(self, level: int = 1) -> None: + """Sleep the engine""" + ... + + @abstractmethod + async def wake_up(self, tags: Optional[list[str]] = None) -> None: + """Wake up the engine""" + ... + + @abstractmethod + async def is_sleeping(self) -> bool: + """Check whether the engine is sleeping""" + ... + + @abstractmethod + async def add_lora(self, lora_request: LoRARequest) -> None: + """Load a new LoRA adapter into the engine for future requests.""" + ... diff --git a/vllm/entrypoints/__init__.py b/vllm/entrypoints/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py new file mode 100644 index 0000000..3d1e5dc --- /dev/null +++ b/vllm/entrypoints/api_server.py @@ -0,0 +1,178 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +NOTE: This API server is used only for demonstrating usage of AsyncEngine +and simple performance benchmarks. It is not intended for production use. +For production use, we recommend using our OpenAI compatible server. +We are also not going to accept PRs modifying this file, please +change `vllm/entrypoints/openai/api_server.py` instead. +""" +import asyncio +import json +import ssl +from argparse import Namespace +from collections.abc import AsyncGenerator +from typing import Any, Optional + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response, StreamingResponse + +import vllm.envs as envs +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.utils import with_cancellation +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit +from vllm.version import __version__ as VLLM_VERSION + +logger = init_logger("vllm.entrypoints.api_server") + +app = FastAPI() +engine = None + + +@app.get("/health") +async def health() -> Response: + """Health check.""" + return Response(status_code=200) + + +@app.post("/generate") +async def generate(request: Request) -> Response: + """Generate completion for the request. + + The request should be a JSON object with the following fields: + - prompt: the prompt to use for the generation. + - stream: whether to stream the results or not. + - other fields: the sampling parameters (See `SamplingParams` for details). + """ + request_dict = await request.json() + return await _generate(request_dict, raw_request=request) + + +@with_cancellation +async def _generate(request_dict: dict, raw_request: Request) -> Response: + prompt = request_dict.pop("prompt") + stream = request_dict.pop("stream", False) + sampling_params = SamplingParams(**request_dict) + request_id = random_uuid() + + assert engine is not None + results_generator = engine.generate(prompt, sampling_params, request_id) + + # Streaming case + async def stream_results() -> AsyncGenerator[bytes, None]: + async for request_output in results_generator: + prompt = request_output.prompt + assert prompt is not None + text_outputs = [ + prompt + output.text for output in request_output.outputs + ] + ret = {"text": text_outputs} + yield (json.dumps(ret) + "\n").encode("utf-8") + + if stream: + return StreamingResponse(stream_results()) + + # Non-streaming case + final_output = None + try: + async for request_output in results_generator: + final_output = request_output + except asyncio.CancelledError: + return Response(status_code=499) + + assert final_output is not None + prompt = final_output.prompt + assert prompt is not None + text_outputs = [prompt + output.text for output in final_output.outputs] + ret = {"text": text_outputs} + return JSONResponse(ret) + + +def build_app(args: Namespace) -> FastAPI: + global app + + app.root_path = args.root_path + return app + + +async def init_app( + args: Namespace, + llm_engine: Optional[AsyncLLMEngine] = None, +) -> FastAPI: + app = build_app(args) + + global engine + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = (llm_engine + if llm_engine is not None else AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.API_SERVER)) + app.state.engine_client = engine + return app + + +async def run_server(args: Namespace, + llm_engine: Optional[AsyncLLMEngine] = None, + **uvicorn_kwargs: Any) -> None: + logger.info("vLLM API server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + set_ulimit() + + app = await init_app(args, llm_engine) + assert engine is not None + + shutdown_task = await serve_http( + app, + sock=None, + enable_ssl_refresh=args.enable_ssl_refresh, + host=args.host, + port=args.port, + log_level=args.log_level, + timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) + + await shutdown_task + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=parser.check_port, default=8000) + parser.add_argument("--ssl-keyfile", type=str, default=None) + parser.add_argument("--ssl-certfile", type=str, default=None) + parser.add_argument("--ssl-ca-certs", + type=str, + default=None, + help="The CA certificates file") + parser.add_argument( + "--enable-ssl-refresh", + action="store_true", + default=False, + help="Refresh SSL Context when SSL certificate files change") + parser.add_argument( + "--ssl-cert-reqs", + type=int, + default=int(ssl.CERT_NONE), + help="Whether client certificate is required (see stdlib ssl module's)" + ) + parser.add_argument( + "--root-path", + type=str, + default=None, + help="FastAPI root_path when app is behind a path based routing proxy") + parser.add_argument("--log-level", type=str, default="debug") + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + + asyncio.run(run_server(args)) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py new file mode 100644 index 0000000..4b6c505 --- /dev/null +++ b/vllm/entrypoints/chat_utils.py @@ -0,0 +1,1278 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import json +from abc import ABC, abstractmethod +from collections import defaultdict, deque +from collections.abc import Awaitable, Iterable +from functools import cached_property, lru_cache, partial +from pathlib import Path +from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union, + cast) + +import jinja2.nodes +import transformers.utils.chat_template_utils as hf_chat_utils +# yapf conflicts with isort for this block +# yapf: disable +from openai.types.chat import (ChatCompletionAssistantMessageParam, + ChatCompletionContentPartImageParam, + ChatCompletionContentPartInputAudioParam) +from openai.types.chat import ( + ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam) +from openai.types.chat import (ChatCompletionContentPartRefusalParam, + ChatCompletionContentPartTextParam) +from openai.types.chat import ( + ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) +from openai.types.chat import (ChatCompletionMessageToolCallParam, + ChatCompletionToolMessageParam) +from openai.types.chat.chat_completion_content_part_input_audio_param import ( + InputAudio) +from PIL import Image +from pydantic import BaseModel, ConfigDict, TypeAdapter +# yapf: enable +from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast, + ProcessorMixin) +# pydantic needs the TypedDict from typing_extensions +from typing_extensions import Required, TypeAlias, TypedDict + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model_cls +from vllm.model_executor.models import SupportsMultiModal +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict +from vllm.multimodal.utils import MediaConnector +# yapf: disable +from vllm.transformers_utils.chat_templates import ( + get_chat_template_fallback_path) +# yapf: enable +from vllm.transformers_utils.processor import cached_get_processor +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import deprecate_kwargs, random_uuid + +logger = init_logger(__name__) + + +class AudioURL(TypedDict, total=False): + url: Required[str] + """ + Either a URL of the audio or a data URL with base64 encoded audio data. + """ + + +class ChatCompletionContentPartAudioParam(TypedDict, total=False): + audio_url: Required[AudioURL] + + type: Required[Literal["audio_url"]] + """The type of the content part.""" + + +class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): + image_embeds: Required[Union[str, dict[str, str]]] + """ + The image embeddings. It can be either: + - A single base64 string. + - A dictionary where each value is a base64 string. + """ + type: Required[Literal["image_embeds"]] + """The type of the content part.""" + + +class VideoURL(TypedDict, total=False): + url: Required[str] + """ + Either a URL of the video or a data URL with base64 encoded video data. + """ + + +class ChatCompletionContentPartVideoParam(TypedDict, total=False): + video_url: Required[VideoURL] + + type: Required[Literal["video_url"]] + """The type of the content part.""" + + +class PILImage(BaseModel): + """ + A PIL.Image.Image object. + """ + image_pil: Image.Image + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class CustomChatCompletionContentPILImageParam(TypedDict, total=False): + """A simpler version of the param that only accepts a PIL image. + + Example: + { + "image_pil": ImageAsset('cherry_blossom').pil_image + } + """ + image_pil: Required[PILImage] + + +class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): + """A simpler version of the param that only accepts a plain image_url. + This is supported by OpenAI API, although it is not documented. + + Example: + { + "image_url": "https://example.com/image.jpg" + } + """ + image_url: Required[str] + + +class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): + """A simpler version of the param that only accepts a plain audio_url. + + Example: + { + "audio_url": "https://example.com/audio.mp3" + } + """ + audio_url: Required[str] + + +class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): + """A simpler version of the param that only accepts a plain audio_url. + + Example: + { + "video_url": "https://example.com/video.mp4" + } + """ + video_url: Required[str] + + +ChatCompletionContentPartParam: TypeAlias = Union[ + OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, + ChatCompletionContentPartInputAudioParam, + ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam, + CustomChatCompletionContentPILImageParam, + CustomChatCompletionContentSimpleImageParam, + ChatCompletionContentPartImageEmbedsParam, + CustomChatCompletionContentSimpleAudioParam, + CustomChatCompletionContentSimpleVideoParam, str] + + +class CustomChatCompletionMessageParam(TypedDict, total=False): + """Enables custom roles in the Chat Completion API.""" + role: Required[str] + """The role of the message's author.""" + + content: Union[str, list[ChatCompletionContentPartParam]] + """The contents of the message.""" + + name: str + """An optional name for the participant. + + Provides the model information to differentiate between participants of the + same role. + """ + + tool_call_id: Optional[str] + """Tool call that this message is responding to.""" + + tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] + """The tool calls generated by the model, such as function calls.""" + + +ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, + CustomChatCompletionMessageParam] + + +# TODO: Make fields ReadOnly once mypy supports it +class ConversationMessage(TypedDict, total=False): + role: Required[str] + """The role of the message's author.""" + + content: Union[Optional[str], list[dict[str, str]]] + """The contents of the message""" + + tool_call_id: Optional[str] + """Tool call that this message is responding to.""" + + name: Optional[str] + """The name of the function to call""" + + tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] + """The tool calls generated by the model, such as function calls.""" + + +# Passed in by user +ChatTemplateContentFormatOption = Literal["auto", "string", "openai"] + +# Used internally +_ChatTemplateContentFormat = Literal["string", "openai"] + + +def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: + if isinstance(node, jinja2.nodes.Name): + return node.ctx == "load" and node.name == varname + + return False + + +def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: + if isinstance(node, jinja2.nodes.Getitem): + return (_is_var_access(node.node, varname) + and isinstance(node.arg, jinja2.nodes.Const) + and node.arg.value == key) + + if isinstance(node, jinja2.nodes.Getattr): + return _is_var_access(node.node, varname) and node.attr == key + + return False + + +def _is_var_or_elems_access( + node: jinja2.nodes.Node, + varname: str, + key: Optional[str] = None, +) -> bool: + if isinstance(node, jinja2.nodes.Filter): + return (node.node is not None + and _is_var_or_elems_access(node.node, varname, key)) + if isinstance(node, jinja2.nodes.Test): + return _is_var_or_elems_access(node.node, varname, key) + + if (isinstance(node, jinja2.nodes.Getitem) + and isinstance(node.arg, jinja2.nodes.Slice)): + return _is_var_or_elems_access(node.node, varname, key) + + # yapf: disable + return ( + _is_attr_access(node, varname, key) if key + else _is_var_access(node, varname) + ) # yapf: enable + + +def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): + # Global variable that is implicitly defined at the root + yield root, varname + + # Iterative BFS + related_varnames = deque([varname]) + while related_varnames: + related_varname = related_varnames.popleft() + + for assign_ast in root.find_all(jinja2.nodes.Assign): + lhs = assign_ast.target + rhs = assign_ast.node + + if _is_var_or_elems_access(rhs, related_varname): + assert isinstance(lhs, jinja2.nodes.Name) + yield assign_ast, lhs.name + + # Avoid infinite looping for self-assignment + if lhs.name != related_varname: + related_varnames.append(lhs.name) + + +# NOTE: The proper way to handle this is to build a CFG so that we can handle +# the scope in which each variable is defined, but that is too complicated +def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node): + messages_varnames = [ + varname + for _, varname in _iter_nodes_assign_var_or_elems(root, "messages") + ] + + # Search for {%- for message in messages -%} loops + for loop_ast in root.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + loop_target = loop_ast.target + + for varname in messages_varnames: + if _is_var_or_elems_access(loop_iter, varname): + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name + break + + +def _iter_nodes_assign_content_item(root: jinja2.nodes.Node): + message_varnames = [ + varname for _, varname in _iter_nodes_assign_messages_item(root) + ] + + # Search for {%- for content in message['content'] -%} loops + for loop_ast in root.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + loop_target = loop_ast.target + + for varname in message_varnames: + if _is_var_or_elems_access(loop_iter, varname, "content"): + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name + break + + +def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]: + try: + jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template) + return jinja_compiled.environment.parse(chat_template) + except Exception: + logger.exception("Error when compiling Jinja template") + return None + + +@lru_cache(maxsize=32) +def _detect_content_format( + chat_template: str, + *, + default: _ChatTemplateContentFormat, +) -> _ChatTemplateContentFormat: + jinja_ast = _try_extract_ast(chat_template) + if jinja_ast is None: + return default + + try: + next(_iter_nodes_assign_content_item(jinja_ast)) + except StopIteration: + return "string" + except Exception: + logger.exception("Error when parsing AST of Jinja template") + return default + else: + return "openai" + + +def resolve_mistral_chat_template( + chat_template: Optional[str], + **kwargs: Any, +) -> Optional[str]: + if chat_template is not None: + logger.warning_once( + "'chat_template' cannot be overridden for mistral tokenizer.") + if "add_generation_prompt" in kwargs: + logger.warning_once( + "'add_generation_prompt' is not supported for mistral tokenizer, " + "so it will be ignored.") + if "continue_final_message" in kwargs: + logger.warning_once( + "'continue_final_message' is not supported for mistral tokenizer, " + "so it will be ignored.") + return None + +@deprecate_kwargs( + "trust_remote_code", + additional_message="Please use `model_config.trust_remote_code` instead.", +) +def resolve_hf_chat_template( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + chat_template: Optional[str], + tools: Optional[list[dict[str, Any]]], + *, + model_config: ModelConfig, + trust_remote_code: Optional[bool] = None, +) -> Optional[str]: + # 1st priority: The given chat template + if chat_template is not None: + return chat_template + + # 2nd priority: AutoProcessor chat template, unless tool calling is enabled + if tools is None: + try: + processor = cached_get_processor( + tokenizer.name_or_path, + processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast, + ProcessorMixin), + trust_remote_code=model_config.trust_remote_code, + ) + if isinstance(processor, ProcessorMixin) and \ + hasattr(processor, 'chat_template') and \ + processor.chat_template is not None: + return processor.chat_template + except Exception: + logger.debug("Failed to load AutoProcessor chat template for %s", tokenizer.name_or_path, exc_info=True) # noqa: E501 + + # 3rd priority: AutoTokenizer chat template + try: + return tokenizer.get_chat_template(chat_template, tools=tools) + except Exception: + logger.debug("Failed to load AutoTokenizer chat template for %s", + tokenizer.name_or_path, exc_info=True) + + # 4th priority: Predefined fallbacks + path = get_chat_template_fallback_path( + model_type=model_config.hf_config.model_type, + tokenizer_name_or_path=model_config.tokenizer, + ) + if path is not None: + logger.info("Loading chat template fallback for %s as there isn't one " + "defined on HF Hub.", tokenizer.name_or_path) + chat_template = load_chat_template(path) + else: + logger.debug("There is no chat template fallback for %s", + tokenizer.name_or_path) + + return chat_template + + +def _resolve_chat_template_content_format( + chat_template: Optional[str], + tools: Optional[list[dict[str, Any]]], + tokenizer: AnyTokenizer, + *, + model_config: ModelConfig, +) -> _ChatTemplateContentFormat: + if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): + hf_chat_template = resolve_hf_chat_template( + tokenizer, + chat_template=chat_template, + tools=tools, + model_config=model_config, + ) + else: + hf_chat_template = None + + jinja_text = (hf_chat_template if isinstance(hf_chat_template, str) + else load_chat_template(chat_template, is_literal=True)) + + detected_format = ("string" if jinja_text is None else + _detect_content_format(jinja_text, default="string")) + + return detected_format + + +@lru_cache +def _log_chat_template_content_format( + chat_template: Optional[str], + given_format: ChatTemplateContentFormatOption, + detected_format: ChatTemplateContentFormatOption, +): + logger.info( + "Detected the chat template content format to be '%s'. " + "You can set `--chat-template-content-format` to override this.", + detected_format, + ) + + if given_format != "auto" and given_format != detected_format: + logger.warning( + "You specified `--chat-template-content-format %s` " + "which is different from the detected format '%s'. " + "If our automatic detection is incorrect, please consider " + "opening a GitHub issue so that we can improve it: " + "https://github.com/vllm-project/vllm/issues/new/choose", + given_format, + detected_format, + ) + + +@deprecate_kwargs( + "trust_remote_code", + additional_message="Please use `model_config.trust_remote_code` instead.", +) +def resolve_chat_template_content_format( + chat_template: Optional[str], + tools: Optional[list[dict[str, Any]]], + given_format: ChatTemplateContentFormatOption, + tokenizer: AnyTokenizer, + *, + model_config: ModelConfig, + trust_remote_code: Optional[bool] = None, +) -> _ChatTemplateContentFormat: + if given_format != "auto": + return given_format + + detected_format = _resolve_chat_template_content_format( + chat_template, + tools, + tokenizer, + model_config=model_config, + ) + + _log_chat_template_content_format( + chat_template, + given_format=given_format, + detected_format=detected_format, + ) + + return detected_format + + + +ModalityStr = Literal["image", "audio", "video", "image_embeds"] +_T = TypeVar("_T") + + +class BaseMultiModalItemTracker(ABC, Generic[_T]): + """ + Tracks multi-modal items in a given request and ensures that the number + of multi-modal items in a given request does not exceed the configured + maximum per prompt. + """ + + def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer): + super().__init__() + + self._model_config = model_config + self._tokenizer = tokenizer + + self._items_by_modality = defaultdict[str, list[_T]](list) + + @property + def model_config(self) -> ModelConfig: + return self._model_config + + @cached_property + def model_cls(self): + return get_model_cls(self.model_config) + + @property + def allowed_local_media_path(self): + return self._model_config.allowed_local_media_path + + @property + def mm_registry(self): + return MULTIMODAL_REGISTRY + + def add(self, modality: ModalityStr, item: _T) -> Optional[str]: + """ + Add a multi-modal item to the current prompt and returns the + placeholder string to use, if any. + """ + mm_registry = self.mm_registry + model_config = self.model_config + model_cls = cast(SupportsMultiModal, self.model_cls) + + input_modality = modality.replace("_embeds", "") + + if mm_registry.has_processor(model_config): + mm_processor = mm_registry.create_processor(model_config) + allowed_counts = mm_processor.info.get_allowed_mm_limits() + allowed_count = allowed_counts.get(input_modality, 0) + else: + mm_config = model_config.multimodal_config + if mm_config is None: + msg = "This model does not support multi-modal inputs" + raise ValueError(msg) + + allowed_count = mm_config.get_limit_per_prompt(input_modality) + + current_count = len(self._items_by_modality[modality]) + 1 + if current_count > allowed_count: + raise ValueError( + f"At most {allowed_count} {modality}(s) may be provided in " + "one request. You can set `--limit-mm-per-prompt` to " + "increase this limit if the model supports it.") + + self._items_by_modality[modality].append(item) + + return model_cls.get_placeholder_str(modality, current_count) + + @abstractmethod + def create_parser(self) -> "BaseMultiModalContentParser": + raise NotImplementedError + + +class MultiModalItemTracker(BaseMultiModalItemTracker[object]): + + def all_mm_data(self) -> Optional[MultiModalDataDict]: + if not self._items_by_modality: + return None + mm_inputs = {} + items_by_modality = dict(self._items_by_modality) + if "image" in items_by_modality and "image_embeds" in items_by_modality: + raise ValueError(\ + "Mixing raw image and embedding inputs is not allowed") + + if "image_embeds" in items_by_modality: + image_embeds_lst = items_by_modality["image_embeds"] + if len(image_embeds_lst) > 1: + raise ValueError(\ + "Only one message can have {'type': 'image_embeds'}") + mm_inputs["image"] = image_embeds_lst[0] + if "image" in items_by_modality: + mm_inputs["image"] = items_by_modality["image"] # A list of images + if "audio" in items_by_modality: + mm_inputs["audio"] = items_by_modality["audio"] # A list of audios + if "video" in items_by_modality: + mm_inputs["video"] = items_by_modality["video"] # A list of videos + return mm_inputs + + def create_parser(self) -> "BaseMultiModalContentParser": + return MultiModalContentParser(self) + + +class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): + + async def all_mm_data(self) -> Optional[MultiModalDataDict]: + if not self._items_by_modality: + return None + mm_inputs = {} + items_by_modality = { + modality: await asyncio.gather(*items) + for modality, items in self._items_by_modality.items() + } + + if "image" in items_by_modality and "image_embeds" in items_by_modality: + raise ValueError( + "Mixing raw image and embedding inputs is not allowed") + + if "image_embeds" in items_by_modality: + image_embeds_lst = items_by_modality["image_embeds"] + if len(image_embeds_lst) > 1: + raise ValueError( + "Only one message can have {'type': 'image_embeds'}") + mm_inputs["image"] = image_embeds_lst[0] + if "image" in items_by_modality: + mm_inputs["image"] = items_by_modality["image"] # A list of images + if "audio" in items_by_modality: + mm_inputs["audio"] = items_by_modality["audio"] # A list of audios + if "video" in items_by_modality: + mm_inputs["video"] = items_by_modality["video"] # A list of videos + return mm_inputs + + def create_parser(self) -> "BaseMultiModalContentParser": + return AsyncMultiModalContentParser(self) + + +class BaseMultiModalContentParser(ABC): + + def __init__(self) -> None: + super().__init__() + + # multimodal placeholder_string : count + self._placeholder_counts: dict[str, int] = defaultdict(lambda: 0) + + def _add_placeholder(self, placeholder: Optional[str]): + if placeholder: + self._placeholder_counts[placeholder] += 1 + + def mm_placeholder_counts(self) -> dict[str, int]: + return dict(self._placeholder_counts) + + @abstractmethod + def parse_image(self, image_url: str) -> None: + raise NotImplementedError + + @abstractmethod + def parse_image_embeds(self, + image_embeds: Union[str, dict[str, str]]) -> None: + raise NotImplementedError + + @abstractmethod + def parse_image_pil(self, image_pil: Image.Image) -> None: + raise NotImplementedError + + @abstractmethod + def parse_audio(self, audio_url: str) -> None: + raise NotImplementedError + + @abstractmethod + def parse_input_audio(self, input_audio: InputAudio) -> None: + raise NotImplementedError + + @abstractmethod + def parse_video(self, video_url: str) -> None: + raise NotImplementedError + + +class MultiModalContentParser(BaseMultiModalContentParser): + + def __init__(self, tracker: MultiModalItemTracker) -> None: + super().__init__() + + self._tracker = tracker + + self._connector = MediaConnector( + media_io_kwargs=self._tracker._model_config.media_io_kwargs, + allowed_local_media_path=tracker.allowed_local_media_path, + ) + + def parse_image(self, image_url: str) -> None: + image = self._connector.fetch_image(image_url) + + placeholder = self._tracker.add("image", image) + self._add_placeholder(placeholder) + + def parse_image_embeds(self, + image_embeds: Union[str, dict[str, str]]) -> None: + if isinstance(image_embeds, dict): + embeds = { + k: self._connector.fetch_image_embedding(v) + for k, v in image_embeds.items() + } + placeholder = self._tracker.add("image_embeds", embeds) + + if isinstance(image_embeds, str): + embedding = self._connector.fetch_image_embedding(image_embeds) + placeholder = self._tracker.add("image_embeds", embedding) + + self._add_placeholder(placeholder) + + def parse_image_pil(self, image_pil: Image.Image) -> None: + placeholder = self._tracker.add("image", image_pil) + self._add_placeholder(placeholder) + + def parse_audio(self, audio_url: str) -> None: + audio = self._connector.fetch_audio(audio_url) + + placeholder = self._tracker.add("audio", audio) + self._add_placeholder(placeholder) + + def parse_input_audio(self, input_audio: InputAudio) -> None: + audio_data = input_audio.get("data", "") + audio_format = input_audio.get("format", "") + audio_url = f"data:audio/{audio_format};base64,{audio_data}" + + return self.parse_audio(audio_url) + + def parse_video(self, video_url: str) -> None: + video = self._connector.fetch_video(video_url=video_url) + + placeholder = self._tracker.add("video", video) + self._add_placeholder(placeholder) + + +class AsyncMultiModalContentParser(BaseMultiModalContentParser): + + def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: + super().__init__() + + self._tracker = tracker + self._connector = MediaConnector( + media_io_kwargs=self._tracker._model_config.media_io_kwargs, + allowed_local_media_path=tracker.allowed_local_media_path + ) + + def parse_image(self, image_url: str) -> None: + image_coro = self._connector.fetch_image_async(image_url) + + placeholder = self._tracker.add("image", image_coro) + self._add_placeholder(placeholder) + + def parse_image_embeds(self, + image_embeds: Union[str, dict[str, str]]) -> None: + future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future() + + if isinstance(image_embeds, dict): + embeds = { + k: self._connector.fetch_image_embedding(v) + for k, v in image_embeds.items() + } + future.set_result(embeds) + + if isinstance(image_embeds, str): + embedding = self._connector.\ + fetch_image_embedding(image_embeds) + future.set_result(embedding) + + placeholder = self._tracker.add("image_embeds", future) + self._add_placeholder(placeholder) + + def parse_image_pil(self, image_pil: Image.Image) -> None: + future: asyncio.Future[Image.Image] = asyncio.Future() + future.set_result(image_pil) + + placeholder = self._tracker.add("image", future) + self._add_placeholder(placeholder) + + def parse_audio(self, audio_url: str) -> None: + audio_coro = self._connector.fetch_audio_async(audio_url) + + placeholder = self._tracker.add("audio", audio_coro) + self._add_placeholder(placeholder) + + def parse_input_audio(self, input_audio: InputAudio) -> None: + audio_data = input_audio.get("data", "") + audio_format = input_audio.get("format", "") + audio_url = f"data:audio/{audio_format};base64,{audio_data}" + + return self.parse_audio(audio_url) + + def parse_video(self, video_url: str) -> None: + video = self._connector.fetch_video_async(video_url=video_url) + + placeholder = self._tracker.add("video", video) + self._add_placeholder(placeholder) + + +def validate_chat_template(chat_template: Optional[Union[Path, str]]): + """Raises if the provided chat template appears invalid.""" + if chat_template is None: + return + + elif isinstance(chat_template, Path) and not chat_template.exists(): + raise FileNotFoundError( + "the supplied chat template path doesn't exist") + + elif isinstance(chat_template, str): + JINJA_CHARS = "{}\n" + if not any(c in chat_template + for c in JINJA_CHARS) and not Path(chat_template).exists(): + raise ValueError( + f"The supplied chat template string ({chat_template}) " + f"appears path-like, but doesn't exist!") + + else: + raise TypeError( + f"{type(chat_template)} is not a valid chat template type") + + +def _load_chat_template( + chat_template: Optional[Union[Path, str]], + *, + is_literal: bool = False, +) -> Optional[str]: + if chat_template is None: + return None + + if is_literal: + if isinstance(chat_template, Path): + raise TypeError("chat_template is expected to be read directly " + "from its value") + + return chat_template + + try: + with open(chat_template) as f: + return f.read() + except OSError as e: + if isinstance(chat_template, Path): + raise + + JINJA_CHARS = "{}\n" + if not any(c in chat_template for c in JINJA_CHARS): + msg = (f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be " + f"opened. Reason: {e}") + raise ValueError(msg) from e + + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + return _load_chat_template(chat_template, is_literal=True) + + +_cached_load_chat_template = lru_cache(_load_chat_template) + + +def load_chat_template( + chat_template: Optional[Union[Path, str]], + *, + is_literal: bool = False, +) -> Optional[str]: + return _cached_load_chat_template(chat_template, is_literal=is_literal) + + +# TODO: Let user specify how to insert multimodal tokens into prompt +# (similar to chat template) +def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int], + text_prompt: str) -> str: + """Combine multimodal prompts for a multimodal language model.""" + + # Look through the text prompt to check for missing placeholders + missing_placeholders: list[str] = [] + for placeholder in placeholder_counts: + + # For any existing placeholder in the text prompt, we leave it as is + placeholder_counts[placeholder] -= text_prompt.count(placeholder) + + if placeholder_counts[placeholder] < 0: + raise ValueError( + f"Found more '{placeholder}' placeholders in input prompt than " + "actual multimodal data items.") + + missing_placeholders.extend([placeholder] * + placeholder_counts[placeholder]) + + # NOTE: For now we always add missing placeholders at the front of + # the prompt. This may change to be customizable in the future. + return "\n".join(missing_placeholders + [text_prompt]) + + +# No need to validate using Pydantic again +_TextParser = partial(cast, ChatCompletionContentPartTextParam) +_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam) +_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam) +_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) +_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam) +# Need to validate url objects +_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python +_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python +_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python + +_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage] + +# Define a mapping from part types to their corresponding parsing functions. +MM_PARSER_MAP: dict[ + str, + Callable[[ChatCompletionContentPartParam], _ContentPart], +] = { + "text": + lambda part: _TextParser(part).get("text", None), + "image_url": + lambda part: _ImageParser(part).get("image_url", {}).get("url", None), + "image_embeds": + lambda part: _ImageEmbedsParser(part).get("image_embeds", None), + "image_pil": lambda part: _PILImageParser(part).get("image_pil", None), + "audio_url": + lambda part: _AudioParser(part).get("audio_url", {}).get("url", None), + "input_audio": + lambda part: _InputAudioParser(part).get("input_audio", None), + "refusal": + lambda part: _RefusalParser(part).get("refusal", None), + "video_url": + lambda part: _VideoParser(part).get("video_url", {}).get("url", None), +} + + +def _parse_chat_message_content_mm_part( + part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]: + """ + Parses a given multi-modal content part based on its type. + + Args: + part: A dict containing the content part, with a potential 'type' field. + + Returns: + A tuple (part_type, content) where: + - part_type: Type of the part (e.g., 'text', 'image_url'). + - content: Parsed content (e.g., text, image URL). + + Raises: + ValueError: If the 'type' field is missing and no direct URL is found. + """ + assert isinstance( + part, dict) # This is needed to avoid mypy errors: part.get() from str + part_type = part.get("type", None) + + if isinstance(part_type, str) and part_type in MM_PARSER_MAP: + content = MM_PARSER_MAP[part_type](part) + + # Special case for 'image_url.detail' + # We only support 'auto', which is the default + if part_type == "image_url" and part.get("detail", "auto") != "auto": + logger.warning("'image_url.detail' is currently not supported " + "and will be ignored.") + + return part_type, content + + # Handle missing 'type' but provided direct URL fields. + # 'type' is required field by pydantic + if part_type is None: + if part.get("image_url") is not None: + image_params = cast(CustomChatCompletionContentSimpleImageParam, + part) + return "image_url", image_params.get("image_url", "") + if part.get("audio_url") is not None: + audio_params = cast(CustomChatCompletionContentSimpleAudioParam, + part) + return "audio_url", audio_params.get("audio_url", "") + if part.get("input_audio") is not None: + input_audio_params = cast(dict[str, str], part) + return "input_audio", input_audio_params + if part.get("video_url") is not None: + video_params = cast(CustomChatCompletionContentSimpleVideoParam, + part) + return "video_url", video_params.get("video_url", "") + # Raise an error if no 'type' or direct URL is found. + raise ValueError("Missing 'type' field in multimodal part.") + + if not isinstance(part_type, str): + raise ValueError("Invalid 'type' field in multimodal part.") + return part_type, "unknown part_type content" + + +VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url", + "image_embeds", "image_pil", + "audio_url", "input_audio", "video_url") + + +def _parse_chat_message_content_parts( + role: str, + parts: Iterable[ChatCompletionContentPartParam], + mm_tracker: BaseMultiModalItemTracker, + *, + wrap_dicts: bool, +) -> list[ConversationMessage]: + content = list[_ContentPart]() + + mm_parser = mm_tracker.create_parser() + + for part in parts: + parse_res = _parse_chat_message_content_part( + part, + mm_parser, + wrap_dicts=wrap_dicts, + ) + if parse_res: + content.append(parse_res) + + if wrap_dicts: + # Parsing wraps images and texts as interleaved dictionaries + return [ConversationMessage(role=role, + content=content)] # type: ignore + texts = cast(list[str], content) + text_prompt = "\n".join(texts) + mm_placeholder_counts = mm_parser.mm_placeholder_counts() + if mm_placeholder_counts: + text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts, + text_prompt) + return [ConversationMessage(role=role, content=text_prompt)] + + +def _parse_chat_message_content_part( + part: ChatCompletionContentPartParam, + mm_parser: BaseMultiModalContentParser, + *, + wrap_dicts: bool, +) -> Optional[_ContentPart]: + """Parses a single part of a conversation. If wrap_dicts is True, + structured dictionary pieces for texts and images will be + wrapped in dictionaries, i.e., {"type": "text", "text", ...} and + {"type": "image"}, respectively. Otherwise multimodal data will be + handled by mm_parser, and texts will be returned as strings to be joined + with multimodal placeholders. + """ + if isinstance(part, str): # Handle plain text parts + return part + + # Handle structured dictionary parts + part_type, content = _parse_chat_message_content_mm_part(part) + + # if part_type is text/refusal/image_url/audio_url/video_url/input_audio but + # content is None, log a warning and skip + if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None: + logger.warning( + "Skipping multimodal part '%s' (type: '%s') " + "with empty / unparsable content.", part, part_type) + return None + + if part_type in ("text", "refusal"): + str_content = cast(str, content) + if wrap_dicts: + return {'type': 'text', 'text': str_content} + else: + return str_content + + if part_type == "image_pil": + image_content = cast(Image.Image, content) + mm_parser.parse_image_pil(image_content) + return {'type': 'image'} if wrap_dicts else None + if part_type == "image_url": + str_content = cast(str, content) + mm_parser.parse_image(str_content) + return {'type': 'image'} if wrap_dicts else None + if part_type == "image_embeds": + content = cast(Union[str, dict[str, str]], content) + mm_parser.parse_image_embeds(content) + return {'type': 'image'} if wrap_dicts else None + if part_type == "audio_url": + str_content = cast(str, content) + mm_parser.parse_audio(str_content) + return {'type': 'audio'} if wrap_dicts else None + + if part_type == "input_audio": + dict_content = cast(InputAudio, content) + mm_parser.parse_input_audio(dict_content) + return {'type': 'audio'} if wrap_dicts else None + + if part_type == "video_url": + str_content = cast(str, content) + mm_parser.parse_video(str_content) + return {'type': 'video'} if wrap_dicts else None + + raise NotImplementedError(f"Unknown part type: {part_type}") + + +# No need to validate using Pydantic again +_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam) +_ToolParser = partial(cast, ChatCompletionToolMessageParam) + + +def _parse_chat_message_content( + message: ChatCompletionMessageParam, + mm_tracker: BaseMultiModalItemTracker, + content_format: _ChatTemplateContentFormat, +) -> list[ConversationMessage]: + role = message["role"] + content = message.get("content") + + if content is None: + content = [] + elif isinstance(content, str): + content = [ + ChatCompletionContentPartTextParam(type="text", text=content) + ] + result = _parse_chat_message_content_parts( + role, + content, # type: ignore + mm_tracker, + wrap_dicts=(content_format == "openai"), + ) + + for result_msg in result: + if role == 'assistant': + parsed_msg = _AssistantParser(message) + + # The 'tool_calls' is not None check ensures compatibility. + # It's needed only if downstream code doesn't strictly + # follow the OpenAI spec. + if ("tool_calls" in parsed_msg + and parsed_msg["tool_calls"] is not None): + result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) + elif role == "tool": + parsed_msg = _ToolParser(message) + if "tool_call_id" in parsed_msg: + result_msg["tool_call_id"] = parsed_msg["tool_call_id"] + + if "name" in message and isinstance(message["name"], str): + result_msg["name"] = message["name"] + + return result + + +def _postprocess_messages(messages: list[ConversationMessage]) -> None: + # per the Transformers docs & maintainers, tool call arguments in + # assistant-role messages with tool_calls need to be dicts not JSON str - + # this is how tool-use chat templates will expect them moving forwards + # so, for messages that have tool_calls, parse the string (which we get + # from openAI format) to dict + for message in messages: + if (message["role"] == "assistant" and "tool_calls" in message + and isinstance(message["tool_calls"], list)): + + for item in message["tool_calls"]: + item["function"]["arguments"] = json.loads( + item["function"]["arguments"]) + + +def parse_chat_messages( + messages: list[ChatCompletionMessageParam], + model_config: ModelConfig, + tokenizer: AnyTokenizer, + content_format: _ChatTemplateContentFormat, +) -> tuple[list[ConversationMessage], Optional[MultiModalDataDict]]: + conversation: list[ConversationMessage] = [] + mm_tracker = MultiModalItemTracker(model_config, tokenizer) + + for msg in messages: + sub_messages = _parse_chat_message_content( + msg, + mm_tracker, + content_format, + ) + + conversation.extend(sub_messages) + + _postprocess_messages(conversation) + + return conversation, mm_tracker.all_mm_data() + + +def parse_chat_messages_futures( + messages: list[ChatCompletionMessageParam], + model_config: ModelConfig, + tokenizer: AnyTokenizer, + content_format: _ChatTemplateContentFormat, +) -> tuple[list[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]: + conversation: list[ConversationMessage] = [] + mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer) + + for msg in messages: + sub_messages = _parse_chat_message_content( + msg, + mm_tracker, + content_format, + ) + + conversation.extend(sub_messages) + + _postprocess_messages(conversation) + + return conversation, mm_tracker.all_mm_data() + + +@deprecate_kwargs( + "trust_remote_code", + additional_message="Please use `model_config.trust_remote_code` instead.", +) +def apply_hf_chat_template( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + conversation: list[ConversationMessage], + chat_template: Optional[str], + tools: Optional[list[dict[str, Any]]], + *, + model_config: ModelConfig, + tokenize: bool = False, # Different from HF's default + # Deprecated, explicitly capture here so it doesn't slit into kwargs. + trust_remote_code: Optional[bool] = None, + **kwargs: Any, +) -> str: + hf_chat_template = resolve_hf_chat_template( + tokenizer, + chat_template=chat_template, + tools=tools, + model_config=model_config, + ) + + if hf_chat_template is None: + raise ValueError( + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one.") + + try: + + return tokenizer.apply_chat_template( + conversation=conversation, # type: ignore[arg-type] + tools=tools, # type: ignore[arg-type] + chat_template=hf_chat_template, + tokenize=tokenize, + **kwargs, + ) + + # External library exceptions can sometimes occur despite the framework's + # internal exception management capabilities. + except Exception as e: + + # Log and report any library-related exceptions for further + # investigation. + logger.exception( + "An error occurred in `transformers` while applying chat template") + raise ValueError(str(e)) from e + +def apply_mistral_chat_template( + tokenizer: MistralTokenizer, + messages: list[ChatCompletionMessageParam], + chat_template: Optional[str], + tools: Optional[list[dict[str, Any]]], + **kwargs: Any, +) -> list[int]: + from mistral_common.exceptions import MistralCommonException + + # The return value of resolve_mistral_chat_template is always None, + # and we won't use it. + resolve_mistral_chat_template( + chat_template=chat_template, + **kwargs, + ) + + try: + return tokenizer.apply_chat_template( + messages=messages, + tools=tools, + **kwargs, + ) + # mistral-common uses assert statements to stop processing of input + # if input does not comply with the expected format. + # We convert those assertion errors to ValueErrors so they can be + # are properly caught in the preprocessing_input step + except (AssertionError, MistralCommonException) as e: + raise ValueError(str(e)) from e + + # External library exceptions can sometimes occur despite the framework's + # internal exception management capabilities. + except Exception as e: + + # Log and report any library-related exceptions for further + # investigation. + logger.exception( + "An error occurred in `mistral_common` while applying chat " + "template") + raise ValueError(str(e)) from e + +def random_tool_call_id() -> str: + return f"chatcmpl-tool-{random_uuid()}" diff --git a/vllm/entrypoints/cli/__init__.py b/vllm/entrypoints/cli/__init__.py new file mode 100644 index 0000000..41671b5 --- /dev/null +++ b/vllm/entrypoints/cli/__init__.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand +from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand +from vllm.entrypoints.cli.benchmark.throughput import ( + BenchmarkThroughputSubcommand) + +__all__: list[str] = [ + "BenchmarkLatencySubcommand", + "BenchmarkServingSubcommand", + "BenchmarkThroughputSubcommand", +] \ No newline at end of file diff --git a/vllm/entrypoints/cli/benchmark/__init__.py b/vllm/entrypoints/cli/benchmark/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/entrypoints/cli/benchmark/base.py b/vllm/entrypoints/cli/benchmark/base.py new file mode 100644 index 0000000..0c22bc7 --- /dev/null +++ b/vllm/entrypoints/cli/benchmark/base.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse + +from vllm.entrypoints.cli.types import CLISubcommand + + +class BenchmarkSubcommandBase(CLISubcommand): + """ The base class of subcommands for vllm bench. """ + + help: str + + @classmethod + def add_cli_args(cls, parser: argparse.ArgumentParser) -> None: + """Add the CLI arguments to the parser.""" + raise NotImplementedError + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + """Run the benchmark. + + Args: + args: The arguments to the command. + """ + raise NotImplementedError diff --git a/vllm/entrypoints/cli/benchmark/latency.py b/vllm/entrypoints/cli/benchmark/latency.py new file mode 100644 index 0000000..3e68963 --- /dev/null +++ b/vllm/entrypoints/cli/benchmark/latency.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse + +from vllm.benchmarks.latency import add_cli_args, main +from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase + + +class BenchmarkLatencySubcommand(BenchmarkSubcommandBase): + """ The `latency` subcommand for vllm bench. """ + + name = "latency" + help = "Benchmark the latency of a single batch of requests." + + @classmethod + def add_cli_args(cls, parser: argparse.ArgumentParser) -> None: + add_cli_args(parser) + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + main(args) diff --git a/vllm/entrypoints/cli/benchmark/main.py b/vllm/entrypoints/cli/benchmark/main.py new file mode 100644 index 0000000..87fb9f3 --- /dev/null +++ b/vllm/entrypoints/cli/benchmark/main.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import argparse +import typing + +from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase +from vllm.entrypoints.cli.types import CLISubcommand +from vllm.entrypoints.utils import (VLLM_SUBCMD_PARSER_EPILOG, + show_filtered_argument_or_group_from_help) + +if typing.TYPE_CHECKING: + from vllm.utils import FlexibleArgumentParser + + +class BenchmarkSubcommand(CLISubcommand): + """ The `bench` subcommand for the vLLM CLI. """ + + name = "bench" + help = "vLLM bench subcommand." + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + args.dispatch_function(args) + + def validate(self, args: argparse.Namespace) -> None: + pass + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + bench_parser = subparsers.add_parser( + self.name, + help=self.help, + description=self.help, + usage="vllm bench [options]") + bench_subparsers = bench_parser.add_subparsers(required=True, + dest="bench_type") + + for cmd_cls in BenchmarkSubcommandBase.__subclasses__(): + cmd_subparser = bench_subparsers.add_parser( + cmd_cls.name, + help=cmd_cls.help, + description=cmd_cls.help, + usage=f"vllm bench {cmd_cls.name} [options]", + ) + cmd_subparser.set_defaults(dispatch_function=cmd_cls.cmd) + cmd_cls.add_cli_args(cmd_subparser) + show_filtered_argument_or_group_from_help(cmd_subparser, + ["bench", cmd_cls.name]) + cmd_subparser.epilog = VLLM_SUBCMD_PARSER_EPILOG + return bench_parser + + +def cmd_init() -> list[CLISubcommand]: + return [BenchmarkSubcommand()] diff --git a/vllm/entrypoints/cli/benchmark/serve.py b/vllm/entrypoints/cli/benchmark/serve.py new file mode 100644 index 0000000..3dd7a46 --- /dev/null +++ b/vllm/entrypoints/cli/benchmark/serve.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse + +from vllm.benchmarks.serve import add_cli_args, main +from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase + + +class BenchmarkServingSubcommand(BenchmarkSubcommandBase): + """ The `serve` subcommand for vllm bench. """ + + name = "serve" + help = "Benchmark the online serving throughput." + + @classmethod + def add_cli_args(cls, parser: argparse.ArgumentParser) -> None: + add_cli_args(parser) + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + main(args) diff --git a/vllm/entrypoints/cli/benchmark/throughput.py b/vllm/entrypoints/cli/benchmark/throughput.py new file mode 100644 index 0000000..d5d43ad --- /dev/null +++ b/vllm/entrypoints/cli/benchmark/throughput.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse + +from vllm.benchmarks.throughput import add_cli_args, main +from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase + + +class BenchmarkThroughputSubcommand(BenchmarkSubcommandBase): + """ The `throughput` subcommand for vllm bench. """ + + name = "throughput" + help = "Benchmark offline inference throughput." + + @classmethod + def add_cli_args(cls, parser: argparse.ArgumentParser) -> None: + add_cli_args(parser) + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + main(args) diff --git a/vllm/entrypoints/cli/collect_env.py b/vllm/entrypoints/cli/collect_env.py new file mode 100644 index 0000000..785c188 --- /dev/null +++ b/vllm/entrypoints/cli/collect_env.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import argparse +import typing + +from vllm.collect_env import main as collect_env_main +from vllm.entrypoints.cli.types import CLISubcommand + +if typing.TYPE_CHECKING: + from vllm.utils import FlexibleArgumentParser + + +class CollectEnvSubcommand(CLISubcommand): + """The `collect-env` subcommand for the vLLM CLI. """ + name = "collect-env" + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + """Collect information about the environment.""" + collect_env_main() + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + return subparsers.add_parser( + "collect-env", + help="Start collecting environment information.", + description="Start collecting environment information.", + usage="vllm collect-env") + + +def cmd_init() -> list[CLISubcommand]: + return [CollectEnvSubcommand()] diff --git a/vllm/entrypoints/cli/main.py b/vllm/entrypoints/cli/main.py new file mode 100644 index 0000000..3e09d45 --- /dev/null +++ b/vllm/entrypoints/cli/main.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +'''The CLI entrypoints of vLLM + +Note that all future modules must be lazily loaded within main +to avoid certain eager import breakage.''' +from __future__ import annotations + +import importlib.metadata +import signal +import sys + + +def register_signal_handlers(): + + def signal_handler(sig, frame): + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTSTP, signal_handler) + + +def main(): + import vllm.entrypoints.cli.benchmark.main + import vllm.entrypoints.cli.collect_env + import vllm.entrypoints.cli.openai + import vllm.entrypoints.cli.run_batch + import vllm.entrypoints.cli.serve + from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG, cli_env_setup + from vllm.utils import FlexibleArgumentParser + + CMD_MODULES = [ + vllm.entrypoints.cli.openai, + vllm.entrypoints.cli.serve, + vllm.entrypoints.cli.benchmark.main, + vllm.entrypoints.cli.collect_env, + vllm.entrypoints.cli.run_batch, + ] + + cli_env_setup() + + parser = FlexibleArgumentParser( + description="vLLM CLI", + epilog=VLLM_SUBCMD_PARSER_EPILOG, + ) + parser.add_argument( + '-v', + '--version', + action='version', + version=importlib.metadata.version('vllm'), + ) + subparsers = parser.add_subparsers(required=False, dest="subparser") + cmds = {} + for cmd_module in CMD_MODULES: + new_cmds = cmd_module.cmd_init() + for cmd in new_cmds: + cmd.subparser_init(subparsers).set_defaults( + dispatch_function=cmd.cmd) + cmds[cmd.name] = cmd + args = parser.parse_args() + if args.subparser in cmds: + cmds[args.subparser].validate(args) + + if hasattr(args, "dispatch_function"): + args.dispatch_function(args) + else: + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/vllm/entrypoints/cli/openai.py b/vllm/entrypoints/cli/openai.py new file mode 100644 index 0000000..5ddaee5 --- /dev/null +++ b/vllm/entrypoints/cli/openai.py @@ -0,0 +1,201 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import argparse +import os +import signal +import sys +from typing import TYPE_CHECKING + +from openai import OpenAI +from openai.types.chat import ChatCompletionMessageParam + +from vllm.entrypoints.cli.types import CLISubcommand + +if TYPE_CHECKING: + from vllm.utils import FlexibleArgumentParser + + +def _register_signal_handlers(): + + def signal_handler(sig, frame): + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTSTP, signal_handler) + + +def _interactive_cli(args: argparse.Namespace) -> tuple[str, OpenAI]: + _register_signal_handlers() + + base_url = args.url + api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY") + openai_client = OpenAI(api_key=api_key, base_url=base_url) + + if args.model_name: + model_name = args.model_name + else: + available_models = openai_client.models.list() + model_name = available_models.data[0].id + + print(f"Using model: {model_name}") + + return model_name, openai_client + + +def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None: + conversation: list[ChatCompletionMessageParam] = [] + if system_prompt is not None: + conversation.append({"role": "system", "content": system_prompt}) + + print("Please enter a message for the chat model:") + while True: + try: + input_message = input("> ") + except EOFError: + return + conversation.append({"role": "user", "content": input_message}) + + chat_completion = client.chat.completions.create(model=model_name, + messages=conversation) + + response_message = chat_completion.choices[0].message + output = response_message.content + + conversation.append(response_message) # type: ignore + print(output) + + +def _add_query_options( + parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + parser.add_argument( + "--url", + type=str, + default="http://localhost:8000/v1", + help="url of the running OpenAI-Compatible RESTful API server") + parser.add_argument( + "--model-name", + type=str, + default=None, + help=("The model name used in prompt completion, default to " + "the first model in list models API call.")) + parser.add_argument( + "--api-key", + type=str, + default=None, + help=( + "API key for OpenAI services. If provided, this api key " + "will overwrite the api key obtained through environment variables." + )) + return parser + + +class ChatCommand(CLISubcommand): + """The `chat` subcommand for the vLLM CLI. """ + name = "chat" + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + model_name, client = _interactive_cli(args) + system_prompt = args.system_prompt + conversation: list[ChatCompletionMessageParam] = [] + + if system_prompt is not None: + conversation.append({"role": "system", "content": system_prompt}) + + if args.quick: + conversation.append({"role": "user", "content": args.quick}) + + chat_completion = client.chat.completions.create( + model=model_name, messages=conversation) + print(chat_completion.choices[0].message.content) + return + + print("Please enter a message for the chat model:") + while True: + try: + input_message = input("> ") + except EOFError: + return + conversation.append({"role": "user", "content": input_message}) + + chat_completion = client.chat.completions.create( + model=model_name, messages=conversation) + + response_message = chat_completion.choices[0].message + output = response_message.content + + conversation.append(response_message) # type: ignore + print(output) + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + chat_parser = subparsers.add_parser( + "chat", + help="Generate chat completions via the running API server.", + description="Generate chat completions via the running API server.", + usage="vllm chat [options]") + _add_query_options(chat_parser) + chat_parser.add_argument( + "--system-prompt", + type=str, + default=None, + help=("The system prompt to be added to the chat template, " + "used for models that support system prompts.")) + chat_parser.add_argument("-q", + "--quick", + type=str, + metavar="MESSAGE", + help=("Send a single prompt as MESSAGE " + "and print the response, then exit.")) + return chat_parser + + +class CompleteCommand(CLISubcommand): + """The `complete` subcommand for the vLLM CLI. """ + name = 'complete' + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + model_name, client = _interactive_cli(args) + + if args.quick: + completion = client.completions.create(model=model_name, + prompt=args.quick) + print(completion.choices[0].text) + return + + print("Please enter prompt to complete:") + while True: + input_prompt = input("> ") + completion = client.completions.create(model=model_name, + prompt=input_prompt) + output = completion.choices[0].text + print(output) + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + complete_parser = subparsers.add_parser( + "complete", + help=("Generate text completions based on the given prompt " + "via the running API server."), + description=("Generate text completions based on the given prompt " + "via the running API server."), + usage="vllm complete [options]") + _add_query_options(complete_parser) + complete_parser.add_argument( + "-q", + "--quick", + type=str, + metavar="PROMPT", + help= + "Send a single prompt and print the completion output, then exit.") + return complete_parser + + +def cmd_init() -> list[CLISubcommand]: + return [ChatCommand(), CompleteCommand()] diff --git a/vllm/entrypoints/cli/run_batch.py b/vllm/entrypoints/cli/run_batch.py new file mode 100644 index 0000000..8649167 --- /dev/null +++ b/vllm/entrypoints/cli/run_batch.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import argparse +import asyncio +import importlib.metadata +import typing + +from vllm.entrypoints.cli.types import CLISubcommand +from vllm.entrypoints.utils import (VLLM_SUBCMD_PARSER_EPILOG, + show_filtered_argument_or_group_from_help) +from vllm.logger import init_logger + +if typing.TYPE_CHECKING: + from vllm.utils import FlexibleArgumentParser + +logger = init_logger(__name__) + + +class RunBatchSubcommand(CLISubcommand): + """The `run-batch` subcommand for vLLM CLI.""" + name = "run-batch" + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + from vllm.entrypoints.openai.run_batch import main as run_batch_main + + logger.info("vLLM batch processing API version %s", + importlib.metadata.version("vllm")) + logger.info("args: %s", args) + + # Start the Prometheus metrics server. + # LLMEngine uses the Prometheus client + # to publish metrics at the /metrics endpoint. + if args.enable_metrics: + from prometheus_client import start_http_server + + logger.info("Prometheus metrics enabled") + start_http_server(port=args.port, addr=args.url) + else: + logger.info("Prometheus metrics disabled") + + asyncio.run(run_batch_main(args)) + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + from vllm.entrypoints.openai.run_batch import make_arg_parser + + run_batch_parser = subparsers.add_parser( + "run-batch", + help="Run batch prompts and write results to file.", + description=( + "Run batch prompts using vLLM's OpenAI-compatible API.\n" + "Supports local or HTTP input/output files."), + usage= + "vllm run-batch -i INPUT.jsonl -o OUTPUT.jsonl --model ", + ) + run_batch_parser = make_arg_parser(run_batch_parser) + show_filtered_argument_or_group_from_help(run_batch_parser, + ["run-batch"]) + run_batch_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG + return run_batch_parser + + +def cmd_init() -> list[CLISubcommand]: + return [RunBatchSubcommand()] diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py new file mode 100644 index 0000000..9e24b31 --- /dev/null +++ b/vllm/entrypoints/cli/serve.py @@ -0,0 +1,265 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import os +import signal +import sys +from typing import Optional + +import uvloop + +import vllm +import vllm.envs as envs +from vllm.entrypoints.cli.types import CLISubcommand +from vllm.entrypoints.openai.api_server import (run_server, run_server_worker, + setup_server) +from vllm.entrypoints.openai.cli_args import (make_arg_parser, + validate_parsed_serve_args) +from vllm.entrypoints.utils import (VLLM_SUBCMD_PARSER_EPILOG, + show_filtered_argument_or_group_from_help) +from vllm.executor.multiproc_worker_utils import _add_prefix +from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser, get_tcp_uri +from vllm.v1.engine.core import EngineCoreProc +from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines +from vllm.v1.executor.abstract import Executor +from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus +from vllm.v1.utils import (APIServerProcessManager, + wait_for_completion_or_failure) + +logger = init_logger(__name__) + + +class ServeSubcommand(CLISubcommand): + """The `serve` subcommand for the vLLM CLI. """ + name = "serve" + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + # If model is specified in CLI (as positional arg), it takes precedence + if hasattr(args, 'model_tag') and args.model_tag is not None: + args.model = args.model_tag + + if args.headless or args.api_server_count < 1: + run_headless(args) + else: + if args.data_parallel_start_rank: + raise ValueError("data_parallel_start_rank is only " + "applicable in headless mode") + if args.api_server_count > 1: + run_multi_api_server(args) + else: + # Single API server (this process). + uvloop.run(run_server(args)) + + def validate(self, args: argparse.Namespace) -> None: + validate_parsed_serve_args(args) + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + serve_parser = subparsers.add_parser( + "serve", + help="Start the vLLM OpenAI Compatible API server.", + description="Start the vLLM OpenAI Compatible API server.", + usage="vllm serve [model_tag] [options]") + serve_parser.add_argument("model_tag", + type=str, + nargs='?', + help="The model tag to serve " + "(optional if specified in config)") + serve_parser.add_argument( + "--headless", + action='store_true', + default=False, + help="Run in headless mode. See multi-node data parallel " + "documentation for more details.") + serve_parser.add_argument( + '--data-parallel-start-rank', + '-dpr', + type=int, + default=0, + help='Starting data parallel rank for secondary nodes.') + serve_parser.add_argument('--api-server-count', + '-asc', + type=int, + default=1, + help='How many API server processes to run.') + serve_parser.add_argument( + "--config", + type=str, + default='', + required=False, + help="Read CLI options from a config file. " + "Must be a YAML with the following options: " + "https://docs.vllm.ai/en/latest/configuration/serve_args.html") + + serve_parser = make_arg_parser(serve_parser) + show_filtered_argument_or_group_from_help(serve_parser, ["serve"]) + serve_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG + return serve_parser + + +def cmd_init() -> list[CLISubcommand]: + return [ServeSubcommand()] + + +def run_headless(args: argparse.Namespace): + + if args.api_server_count > 1: + raise ValueError("api_server_count can't be set in headless mode") + + # Create the EngineConfig. + engine_args = vllm.AsyncEngineArgs.from_cli_args(args) + usage_context = UsageContext.OPENAI_API_SERVER + vllm_config = engine_args.create_engine_config(usage_context=usage_context) + + if not envs.VLLM_USE_V1: + raise ValueError("Headless mode is only supported for V1") + + if engine_args.data_parallel_rank is not None: + raise ValueError("data_parallel_rank is not applicable in " + "headless mode") + + parallel_config = vllm_config.parallel_config + local_engine_count = parallel_config.data_parallel_size_local + + if local_engine_count <= 0: + raise ValueError("data_parallel_size_local must be > 0 in " + "headless mode") + + host = parallel_config.data_parallel_master_ip + port = engine_args.data_parallel_rpc_port # add to config too + handshake_address = get_tcp_uri(host, port) + + # Catch SIGTERM and SIGINT to allow graceful shutdown. + def signal_handler(signum, frame): + logger.debug("Received %d signal.", signum) + raise SystemExit + + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + + logger.info( + "Launching %d data parallel engine(s) in headless mode, " + "with head node address %s.", local_engine_count, handshake_address) + + # Create the engines. + engine_manager = CoreEngineProcManager( + target_fn=EngineCoreProc.run_engine_core, + local_engine_count=local_engine_count, + start_index=args.data_parallel_start_rank, + local_start_index=0, + vllm_config=vllm_config, + local_client=False, + handshake_address=handshake_address, + executor_class=Executor.get_class(vllm_config), + log_stats=not engine_args.disable_log_stats, + ) + + try: + engine_manager.join_first() + finally: + logger.info("Shutting down.") + engine_manager.close() + + +def run_multi_api_server(args: argparse.Namespace): + + assert not args.headless + num_api_servers = args.api_server_count + assert num_api_servers > 0 + + if num_api_servers > 1: + setup_multiprocess_prometheus() + + listen_address, sock = setup_server(args) + + engine_args = vllm.AsyncEngineArgs.from_cli_args(args) + usage_context = UsageContext.OPENAI_API_SERVER + vllm_config = engine_args.create_engine_config(usage_context=usage_context) + model_config = vllm_config.model_config + + if num_api_servers > 1: + if not envs.VLLM_USE_V1: + raise ValueError("api_server_count > 1 is only supported for V1") + + if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used " + "with api_server_count > 1") + + if model_config.is_multimodal_model and not ( + model_config.disable_mm_preprocessor_cache): + logger.warning( + "Multi-model preprocessor cache will be disabled for" + " api_server_count > 1") + model_config.disable_mm_preprocessor_cache = True + + executor_class = Executor.get_class(vllm_config) + log_stats = not engine_args.disable_log_stats + + parallel_config = vllm_config.parallel_config + dp_rank = parallel_config.data_parallel_rank + external_dp_lb = parallel_config.data_parallel_external_lb + assert external_dp_lb or dp_rank == 0 + + api_server_manager: Optional[APIServerProcessManager] = None + + with launch_core_engines(vllm_config, executor_class, log_stats, + num_api_servers) as (local_engine_manager, + coordinator, addresses): + + # Construct common args for the APIServerProcessManager up-front. + api_server_manager_kwargs = dict( + target_server_fn=run_api_server_worker_proc, + listen_address=listen_address, + sock=sock, + args=args, + num_servers=num_api_servers, + input_addresses=addresses.inputs, + output_addresses=addresses.outputs, + stats_update_address=coordinator.get_stats_publish_address() + if coordinator else None) + + # For dp ranks > 0 in external DP LB mode, we must delay the + # start of the API servers until the local engine is started + # (after the launcher context manager exits), + # since we get the front-end stats update address from the coordinator + # via the handshake with the local engine. + if dp_rank == 0 or not external_dp_lb: + # Start API servers using the manager. + api_server_manager = APIServerProcessManager( + **api_server_manager_kwargs) + + # Start API servers now if they weren't already started. + if api_server_manager is None: + api_server_manager_kwargs["stats_update_address"] = ( + addresses.frontend_stats_publish_address) + api_server_manager = APIServerProcessManager( + **api_server_manager_kwargs) + + # Wait for API servers + wait_for_completion_or_failure(api_server_manager=api_server_manager, + engine_manager=local_engine_manager, + coordinator=coordinator) + + +def run_api_server_worker_proc(listen_address, + sock, + args, + client_config=None, + **uvicorn_kwargs) -> None: + """Entrypoint for individual API server worker processes.""" + + # Add process-specific prefix to stdout and stderr. + from multiprocessing import current_process + process_name = current_process().name + pid = os.getpid() + _add_prefix(sys.stdout, process_name, pid) + _add_prefix(sys.stderr, process_name, pid) + + uvloop.run( + run_server_worker(listen_address, sock, args, client_config, + **uvicorn_kwargs)) diff --git a/vllm/entrypoints/cli/types.py b/vllm/entrypoints/cli/types.py new file mode 100644 index 0000000..b88f094 --- /dev/null +++ b/vllm/entrypoints/cli/types.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import argparse +import typing + +if typing.TYPE_CHECKING: + from vllm.utils import FlexibleArgumentParser + + +class CLISubcommand: + """Base class for CLI argument handlers.""" + + name: str + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + raise NotImplementedError("Subclasses should implement this method") + + def validate(self, args: argparse.Namespace) -> None: + # No validation by default + pass + + def subparser_init( + self, + subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + raise NotImplementedError("Subclasses should implement this method") diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py new file mode 100644 index 0000000..8455031 --- /dev/null +++ b/vllm/entrypoints/launcher.py @@ -0,0 +1,148 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import signal +import socket +from http import HTTPStatus +from typing import Any, Optional + +import uvicorn +from fastapi import FastAPI, Request, Response + +from vllm import envs +from vllm.engine.async_llm_engine import AsyncEngineDeadError +from vllm.engine.multiprocessing import MQEngineDeadError +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.ssl import SSLCertRefresher +from vllm.logger import init_logger +from vllm.utils import find_process_using_port +from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError + +logger = init_logger(__name__) + + +async def serve_http(app: FastAPI, + sock: Optional[socket.socket], + enable_ssl_refresh: bool = False, + **uvicorn_kwargs: Any): + logger.info("Available routes are:") + for route in app.routes: + methods = getattr(route, "methods", None) + path = getattr(route, "path", None) + + if methods is None or path is None: + continue + + logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) + + config = uvicorn.Config(app, **uvicorn_kwargs) + config.load() + server = uvicorn.Server(config) + _add_shutdown_handlers(app, server) + + loop = asyncio.get_running_loop() + + watchdog_task = loop.create_task( + watchdog_loop(server, app.state.engine_client)) + server_task = loop.create_task( + server.serve(sockets=[sock] if sock else None)) + + ssl_cert_refresher = None if not enable_ssl_refresh else SSLCertRefresher( + ssl_context=config.ssl, + key_path=config.ssl_keyfile, + cert_path=config.ssl_certfile, + ca_path=config.ssl_ca_certs) + + def signal_handler() -> None: + # prevents the uvicorn signal handler to exit early + server_task.cancel() + watchdog_task.cancel() + if ssl_cert_refresher: + ssl_cert_refresher.stop() + + async def dummy_shutdown() -> None: + pass + + loop.add_signal_handler(signal.SIGINT, signal_handler) + loop.add_signal_handler(signal.SIGTERM, signal_handler) + + try: + await server_task + return dummy_shutdown() + except asyncio.CancelledError: + port = uvicorn_kwargs["port"] + process = find_process_using_port(port) + if process is not None: + logger.debug( + "port %s is used by process %s launched with command:\n%s", + port, process, " ".join(process.cmdline())) + logger.info("Shutting down FastAPI HTTP server.") + + return server.shutdown() + finally: + watchdog_task.cancel() + + +async def watchdog_loop(server: uvicorn.Server, engine: EngineClient): + """ + # Watchdog task that runs in the background, checking + # for error state in the engine. Needed to trigger shutdown + # if an exception arises is StreamingResponse() generator. + """ + VLLM_WATCHDOG_TIME_S = 5.0 + while True: + await asyncio.sleep(VLLM_WATCHDOG_TIME_S) + terminate_if_errored(server, engine) + + +def terminate_if_errored(server: uvicorn.Server, engine: EngineClient): + """ + See discussions here on shutting down a uvicorn server + https://github.com/encode/uvicorn/discussions/1103 + In this case we cannot await the server shutdown here + because handler must first return to close the connection + for this request. + """ + engine_errored = engine.errored and not engine.is_running + if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine_errored: + server.should_exit = True + + +def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: + """ + VLLM V1 AsyncLLM catches exceptions and returns + only two types: EngineGenerateError and EngineDeadError. + + EngineGenerateError is raised by the per request generate() + method. This error could be request specific (and therefore + recoverable - e.g. if there is an error in input processing). + + EngineDeadError is raised by the background output_handler + method. This error is global and therefore not recoverable. + + We register these @app.exception_handlers to return nice + responses to the end user if they occur and shut down if needed. + See https://fastapi.tiangolo.com/tutorial/handling-errors/ + for more details on how exception handlers work. + + If an exception is encountered in a StreamingResponse + generator, the exception is not raised, since we already sent + a 200 status. Rather, we send an error message as the next chunk. + Since the exception is not raised, this means that the server + will not automatically shut down. Instead, we use the watchdog + background task for check for errored state. + """ + + @app.exception_handler(RuntimeError) + @app.exception_handler(AsyncEngineDeadError) + @app.exception_handler(MQEngineDeadError) + @app.exception_handler(EngineDeadError) + @app.exception_handler(EngineGenerateError) + async def runtime_exception_handler(request: Request, __): + terminate_if_errored( + server=server, + engine=request.app.state.engine_client, + ) + + return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py new file mode 100644 index 0000000..bf5ddb9 --- /dev/null +++ b/vllm/entrypoints/llm.py @@ -0,0 +1,1609 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import itertools +import warnings +from collections.abc import Sequence +from contextlib import contextmanager +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union, + cast, overload) + +import cloudpickle +import torch.nn as nn +from pydantic import ValidationError +from tqdm.auto import tqdm +from typing_extensions import TypeVar, deprecated + +from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, + BeamSearchSequence, + create_sort_beams_key_function) +from vllm.config import (CompilationConfig, ModelDType, TokenizerMode, + is_init_field) +from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig, + TaskOption) +from vllm.engine.llm_engine import LLMEngine +from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, + ChatTemplateContentFormatOption, + apply_hf_chat_template, + apply_mistral_chat_template, + parse_chat_messages, + resolve_chat_template_content_format) +from vllm.entrypoints.score_utils import (_cosine_similarity, + _validate_score_input_lens) +from vllm.entrypoints.utils import _validate_truncation_size +from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt +from vllm.inputs.parse import parse_and_batch_prompt +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding.guided_fields import ( + GuidedDecodingRequest, LLMGuidedOptions) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput, + PoolingRequestOutput, RequestOutput, + ScoringRequestOutput) +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, + RequestOutputKind, SamplingParams) +from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, + get_cached_tokenizer) +from vllm.usage.usage_lib import UsageContext + +from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of +import vllm.envs as envs +from vllm.zero_overhead.llm_engine import ZeroOverheadEngine + +if TYPE_CHECKING: + from vllm.v1.metrics.reader import Metric + +logger = init_logger(__name__) + +_R = TypeVar("_R", default=Any) + + +class LLM: + """An LLM for generating texts from given prompts and sampling parameters. + + This class includes a tokenizer, a language model (possibly distributed + across multiple GPUs), and GPU memory space allocated for intermediate + states (aka KV cache). Given a batch of prompts and sampling parameters, + this class generates texts from the model, using an intelligent batching + mechanism and efficient memory management. + + Args: + model: The name or path of a HuggingFace Transformers model. + tokenizer: The name or path of a HuggingFace Transformers tokenizer. + tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer + if available, and "slow" will always use the slow tokenizer. + skip_tokenizer_init: If true, skip initialization of tokenizer and + detokenizer. Expect valid prompt_token_ids and None for prompt + from the input. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + allowed_local_media_path: Allowing API requests to read local images + or videos from directories specified by the server file system. + This is a security risk. Should only be enabled in trusted + environments. + tensor_parallel_size: The number of GPUs to use for distributed + execution with tensor parallelism. + dtype: The data type for the model weights and activations. Currently, + we support `float32`, `float16`, and `bfloat16`. If `auto`, we use + the `torch_dtype` attribute specified in the model config file. + However, if the `torch_dtype` in the config is `float32`, we will + use `float16` instead. + quantization: The method used to quantize the model weights. Currently, + we support "awq", "gptq", and "fp8" (experimental). + If None, we first check the `quantization_config` attribute in the + model config file. If that is None, we assume the model weights are + not quantized and use `dtype` to determine the data type of + the weights. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. + seed: The seed to initialize the random number generator for sampling. + gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to + reserve for the model weights, activations, and KV cache. Higher + values will increase the KV cache size and thus improve the model's + throughput. However, if the value is too high, it may cause out-of- + memory (OOM) errors. + swap_space: The size (GiB) of CPU memory per GPU to use as swap space. + This can be used for temporarily storing the states of the requests + when their `best_of` sampling parameters are larger than 1. If all + requests will have `best_of=1`, you can safely set this to 0. + Noting that `best_of` is only supported in V0. Otherwise, too small + values may cause out-of-memory (OOM) errors. + cpu_offload_gb: The size (GiB) of CPU memory to use for offloading + the model weights. This virtually increases the GPU memory space + you can use to hold the model weights, at the cost of CPU-GPU data + transfer for every forward pass. + enforce_eager: Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode. Additionally for encoder-decoder models, if the + sequence length of the encoder input is larger than this, we fall + back to the eager mode. + disable_custom_all_reduce: See + [ParallelConfig][vllm.config.ParallelConfig]. + disable_async_output_proc: Disable async output processing. + This may result in lower performance. + hf_token: The token to use as HTTP bearer authorization for remote files + . If `True`, will use the token generated when running + `huggingface-cli login` (stored in `~/.huggingface`). + hf_overrides: If a dictionary, contains arguments to be forwarded to the + HuggingFace config. If a callable, it is called to update the + HuggingFace config. + mm_processor_kwargs: Arguments to be forwarded to the model's processor + for multi-modal data, e.g., image processor. Overrides for the + multi-modal processor obtained from `AutoProcessor.from_pretrained`. + The available overrides depend on the model that is being run. + For example, for Phi-3-Vision: `{"num_crops": 4}`. + override_pooler_config: Initialize non-default pooling config or + override default pooling config for the pooling model. + e.g. `PoolerConfig(pooling_type="mean", normalize=False)`. + compilation_config: Either an integer or a dictionary. If it is an + integer, it is used as the level of compilation optimization. If it + is a dictionary, it can specify the full compilation configuration. + **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs]. + + Note: + This class is intended to be used for offline inference. For online + serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead. + """ + + DEPRECATE_LEGACY: ClassVar[bool] = True + """A flag to toggle whether to deprecate the legacy generate/encode API.""" + + @classmethod + @contextmanager + def deprecate_legacy_api(cls): + cls.DEPRECATE_LEGACY = True + + yield + + cls.DEPRECATE_LEGACY = False + + def __init__( + self, + model: str, + *, + task: TaskOption = "auto", + tokenizer: Optional[str] = None, + #need change mode as "cpm" for 9g tokenizer + # tokenizer_mode: TokenizerMode = "cpm", + tokenizer_mode: TokenizerMode = "auto", + skip_tokenizer_init: bool = False, + trust_remote_code: bool = False, + allowed_local_media_path: str = "", + tensor_parallel_size: int = 1, + dtype: ModelDType = "auto", + quantization: Optional[QuantizationMethods] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: Optional[int] = None, + gpu_memory_utilization: float = 0.9, + swap_space: float = 4, + cpu_offload_gb: float = 0, + enforce_eager: bool = False, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + disable_async_output_proc: bool = False, + hf_token: Optional[Union[bool, str]] = None, + hf_overrides: Optional[HfOverrides] = None, + mm_processor_kwargs: Optional[dict[str, Any]] = None, + override_pooler_config: Optional[PoolerConfig] = None, + compilation_config: Optional[Union[int, dict[str, Any], + CompilationConfig]] = None, + **kwargs, + ) -> None: + """LLM constructor.""" + + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + + if "worker_cls" in kwargs: + worker_cls = kwargs["worker_cls"] + # if the worker_cls is not qualified string name, + # we serialize it using cloudpickle to avoid pickling issues + if isinstance(worker_cls, type): + kwargs["worker_cls"] = cloudpickle.dumps(worker_cls) + + if "kv_transfer_config" in kwargs and isinstance( + kwargs["kv_transfer_config"], dict): + from vllm.config import KVTransferConfig + raw_config_dict = kwargs["kv_transfer_config"] + try: + kwargs["kv_transfer_config"] = KVTransferConfig( + **raw_config_dict) + except ValidationError as e: + logger.error( + "Failed to convert 'kv_transfer_config' dict to " + "KVTransferConfig object. Dict: %s. Error: %s", + raw_config_dict, e) + # Consider re-raising a more specific vLLM error or ValueError + # to provide better context to the user. + raise ValueError( + f"Invalid 'kv_transfer_config' provided: {e}") from e + + if hf_overrides is None: + hf_overrides = {} + + if compilation_config is not None: + if isinstance(compilation_config, int): + compilation_config_instance = CompilationConfig( + level=compilation_config) + elif isinstance(compilation_config, dict): + predicate = lambda x: is_init_field(CompilationConfig, x[0]) + compilation_config_instance = CompilationConfig( + **dict(filter(predicate, compilation_config.items()))) + else: + compilation_config_instance = compilation_config + else: + compilation_config_instance = CompilationConfig() + + engine_args = EngineArgs( + model=model, + task=task, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + allowed_local_media_path=allowed_local_media_path, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + disable_async_output_proc=disable_async_output_proc, + hf_token=hf_token, + hf_overrides=hf_overrides, + mm_processor_kwargs=mm_processor_kwargs, + override_pooler_config=override_pooler_config, + compilation_config=compilation_config_instance, + **kwargs, + ) + + # Create the Engine (autoselects V0 vs V1) + if envs.VLLM_ZERO_OVERHEAD: + self.llm_engine = ZeroOverheadEngine.from_engine_args( + engine_args=engine_args, usage_context=UsageContext.LLM_CLASS) + else: + self.llm_engine = LLMEngine.from_engine_args( + engine_args=engine_args, usage_context=UsageContext.LLM_CLASS) + self.engine_class = type(self.llm_engine) + + self.request_counter = Counter() + self.default_sampling_params: Union[dict[str, Any], None] = None + + def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + return self.llm_engine.get_tokenizer_group().get_lora_tokenizer( + lora_request) + + def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: + tokenizer_group = self.llm_engine.get_tokenizer_group() + + # While CachedTokenizer is dynamic, have no choice but + # compare class name. Misjudgment will arise from + # user-defined tokenizer started with 'Cached' + if tokenizer.__class__.__name__.startswith("Cached"): + tokenizer_group.tokenizer = tokenizer + else: + tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer) + + def get_default_sampling_params(self) -> SamplingParams: + if self.default_sampling_params is None: + self.default_sampling_params = ( + self.llm_engine.model_config.get_diff_sampling_param()) + if self.default_sampling_params: + return SamplingParams.from_optional(**self.default_sampling_params) + return SamplingParams() + + @overload + def generate( + self, + prompts: Union[PromptType, Sequence[PromptType]], + /, + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, + *, + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, + ) -> list[RequestOutput]: + ... + + @overload # LEGACY: single (prompt + optional token ids) + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def generate( + self, + prompts: str, + sampling_params: Optional[Union[SamplingParams, + list[SamplingParams]]] = None, + prompt_token_ids: Optional[list[int]] = None, + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, + ) -> list[RequestOutput]: + ... + + @overload # LEGACY: multi (prompt + optional token ids) + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def generate( + self, + prompts: list[str], + sampling_params: Optional[Union[SamplingParams, + list[SamplingParams]]] = None, + prompt_token_ids: Optional[list[list[int]]] = None, + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, + ) -> list[RequestOutput]: + ... + + @overload # LEGACY: single (token ids + optional prompt) + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def generate( + self, + prompts: Optional[str] = None, + sampling_params: Optional[Union[SamplingParams, + list[SamplingParams]]] = None, + *, + prompt_token_ids: list[int], + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, + ) -> list[RequestOutput]: + ... + + @overload # LEGACY: multi (token ids + optional prompt) + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def generate( + self, + prompts: Optional[list[str]] = None, + sampling_params: Optional[Union[SamplingParams, + list[SamplingParams]]] = None, + *, + prompt_token_ids: list[list[int]], + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, + ) -> list[RequestOutput]: + ... + + @overload # LEGACY: single or multi token ids [pos-only] + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def generate( + self, + prompts: None, + sampling_params: None, + prompt_token_ids: Union[list[int], list[list[int]]], + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, + ) -> list[RequestOutput]: + ... + + @deprecate_kwargs( + "prompt_token_ids", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'prompts' parameter instead.", + ) + def generate( + self, + prompts: Union[Union[PromptType, Sequence[PromptType]], + Optional[Union[str, list[str]]]] = None, + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, + prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, + priority: Optional[list[int]] = None, + ) -> list[RequestOutput]: + """Generates the completions for the input prompts. + + This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See [PromptType][vllm.inputs.PromptType] + for more details about the format of each prompts. + sampling_params: The sampling parameters for text generation. If + None, we use the default sampling parameters. + When it is a single value, it is applied to every prompt. + When it is a list, the list must have the same length as the + prompts and it is paired one by one with the prompt. + use_tqdm: If `True`, shows a tqdm progress bar. + If a callable (e.g., `functools.partial(tqdm, leave=False)`), + it is used to create the progress bar. + If `False`, no progress bar is created. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. + priority: The priority of the requests, if any. + Only applicable when priority scheduling policy is enabled. + + Returns: + A list of `RequestOutput` objects containing the + generated completions in the same order as the input prompts. + + Note: + Using `prompts` and `prompt_token_ids` as keyword parameters is + considered legacy and may be deprecated in the future. You should + instead pass them via the `inputs` parameter. + """ + runner_type = self.llm_engine.model_config.runner_type + if runner_type not in ["generate", "transcription"]: + messages = [ + "LLM.generate() is only supported for (conditional) generation " + "models (XForCausalLM, XForConditionalGeneration).", + ] + + supported_runner_types = self.llm_engine.model_config \ + .supported_runner_types + if "generate" in supported_runner_types: + messages.append( + "Your model supports the 'generate' runner, but is " + f"currently initialized for the '{runner_type}' runner. " + "Please initialize vLLM using `--task generate`.") + + raise ValueError(" ".join(messages)) + + if prompt_token_ids is not None: + parsed_prompts = self._convert_v1_inputs( + prompts=cast(Optional[Union[str, list[str]]], prompts), + prompt_token_ids=prompt_token_ids, + ) + else: + parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], + prompts) + + if isinstance(guided_options_request, dict): + if len(guided_options_request) > 1: + raise ValueError( + "You can only use one guided decoding but multiple is " + f"specified: {guided_options_request}") + guided_options_request = GuidedDecodingRequest( + **guided_options_request) + + if sampling_params is None: + # Use default sampling params. + sampling_params = self.get_default_sampling_params() + + tokenization_kwargs: dict[str, Any] = {} + truncate_prompt_tokens = None + if isinstance(sampling_params, SamplingParams): + truncate_prompt_tokens = sampling_params.truncate_prompt_tokens + _validate_truncation_size(self.llm_engine.model_config.max_model_len, + truncate_prompt_tokens, tokenization_kwargs) + + self._validate_and_add_requests( + prompts=parsed_prompts, + params=sampling_params, + use_tqdm=use_tqdm, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + guided_options=guided_options_request, + tokenization_kwargs=tokenization_kwargs, + priority=priority, + ) + + outputs = self._run_engine(use_tqdm=use_tqdm) + return self.engine_class.validate_outputs(outputs, RequestOutput) + + def collective_rpc(self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + """ + Execute an RPC call on all workers. + + Args: + method: Name of the worker method to execute, or a callable that + is serialized and sent to all workers to execute. + + If the method is a callable, it should accept an additional + `self` argument, in addition to the arguments passed in `args` + and `kwargs`. The `self` argument will be the worker object. + timeout: Maximum time in seconds to wait for execution. Raises a + [`TimeoutError`][] on timeout. `None` means wait indefinitely. + args: Positional arguments to pass to the worker method. + kwargs: Keyword arguments to pass to the worker method. + + Returns: + A list containing the results from each worker. + + Note: + It is recommended to use this API to only pass control messages, + and set up data-plane communication to pass data. + """ + + return self.llm_engine.collective_rpc(method, timeout, args, kwargs) + + def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: + """ + Run a function directly on the model inside each worker, + returning the result for each of them. + """ + executor = self.llm_engine.model_executor + return executor.apply_model(func) + + def _get_beam_search_lora_requests( + self, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]], + prompts: list[Union[TokensPrompt, TextPrompt]], + ) -> list[Optional[LoRARequest]]: + """Get the optional lora request corresponding to each prompt.""" + if isinstance(lora_request, + Sequence) and len(lora_request) != len(prompts): + raise ValueError( + "Lora request list should be the same length as the prompts") + + if lora_request is None or isinstance(lora_request, LoRARequest): + return [lora_request] * len(prompts) + + raise TypeError(f"Invalid lora_request type {type(lora_request)}") + + def beam_search( + self, + prompts: list[Union[TokensPrompt, TextPrompt]], + params: BeamSearchParams, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + use_tqdm: bool = False, + ) -> list[BeamSearchOutput]: + """ + Generate sequences using beam search. + + Args: + prompts: A list of prompts. Each prompt can be a string or a list + of token IDs. + params: The beam search parameters. + lora_request: LoRA request to use for generation, if any. + use_tqdm: Whether to use tqdm to display the progress bar. + """ + # TODO: how does beam search work together with length penalty, + # frequency, penalty, and stopping criteria, etc.? + beam_width = params.beam_width + max_tokens = params.max_tokens + temperature = params.temperature + ignore_eos = params.ignore_eos + length_penalty = params.length_penalty + + lora_requests = self._get_beam_search_lora_requests( + lora_request, prompts) + + tokenizer = self.get_tokenizer() + sort_beams_key = create_sort_beams_key_function( + tokenizer.eos_token_id, + length_penalty, + ) + + def create_tokens_prompt_from_beam( + beam: BeamSearchSequence) -> TokensPrompt: + token_prompt_kwargs: TokensPrompt = { + "prompt_token_ids": beam.tokens + } + if beam.multi_modal_data is not None: + token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data + + if beam.mm_processor_kwargs is not None: + token_prompt_kwargs[ + "mm_processor_kwargs"] = beam.mm_processor_kwargs + return TokensPrompt(**token_prompt_kwargs) + + # generate 2 * beam_width candidates at each step + # following the huggingface transformers implementation + # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa + beam_search_params = SamplingParams(logprobs=2 * beam_width, + max_tokens=1, + temperature=temperature) + instances: list[BeamSearchInstance] = [] + + for lora_req, prompt in zip(lora_requests, prompts): + # Add multimodal processor kwargs & data + mm_kwargs = {} + if "multi_modal_data" in prompt: + mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"] + if "mm_processor_kwargs" in prompt: + mm_kwargs["mm_processor_kwargs"] = prompt[ + "mm_processor_kwargs"] + + if "prompt_token_ids" in prompt: + prompt = cast(TokensPrompt, prompt) # Needed for mypy + prompt_tokens = prompt["prompt_token_ids"] + else: + prompt_tokens = tokenizer.encode(prompt["prompt"]) + + instances.append( + BeamSearchInstance( + prompt_tokens, + lora_request=lora_req, + logprobs=None, + **mm_kwargs, + ), ) + + token_iter = range(max_tokens) + if use_tqdm: + token_iter = tqdm(token_iter, + desc="Beam search", + unit="token", + unit_scale=False) + logger.warning( + "The progress bar shows the upper bound on token steps and " + "may finish early due to stopping conditions. It does not " + "reflect instance-level progress.") + + for _ in token_iter: + all_beams: list[BeamSearchSequence] = list( + sum((instance.beams for instance in instances), [])) + pos = [0] + list( + itertools.accumulate( + len(instance.beams) for instance in instances)) + instance_start_and_end: list[tuple[int, int]] = list( + zip(pos[:-1], pos[1:])) + + if len(all_beams) == 0: + break + + # create the corresponding batch entries for prompt & optional lora + prompts_batch, lora_req_batch = zip( + *[(create_tokens_prompt_from_beam(beam), beam.lora_request) + for beam in all_beams]) + + # only runs for one step + # we don't need to use tqdm here + output = self.generate(prompts_batch, + sampling_params=beam_search_params, + use_tqdm=False, + lora_request=lora_req_batch) + + for (start, end), instance in zip(instance_start_and_end, + instances): + instance_new_beams = [] + for i in range(start, end): + current_beam = all_beams[i] + result = output[i] + + if result.outputs[0].logprobs is not None: + # if `result.outputs[0].logprobs` is None, it means + # the sequence is completed because of the max-model-len + # or abortion. we don't need to add it to the new beams. + logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + new_beam = BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + logprobs=current_beam.logprobs + [logprobs], + lora_request=current_beam.lora_request, + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + multi_modal_data=current_beam.multi_modal_data, + mm_processor_kwargs=current_beam. + mm_processor_kwargs) + + if token_id == tokenizer.eos_token_id and \ + not ignore_eos: + instance.completed.append(new_beam) + else: + instance_new_beams.append(new_beam) + sorted_beams = sorted(instance_new_beams, + key=sort_beams_key, + reverse=True) + instance.beams = sorted_beams[:beam_width] + + outputs = [] + for instance in instances: + instance.completed.extend(instance.beams) + sorted_completed = sorted(instance.completed, + key=sort_beams_key, + reverse=True) + best_beams = sorted_completed[:beam_width] + + for beam in best_beams: + beam.text = tokenizer.decode(beam.tokens) + outputs.append(BeamSearchOutput(sequences=best_beams)) + + return outputs + + def chat( + self, + messages: Union[list[ChatCompletionMessageParam], + list[list[ChatCompletionMessageParam]]], + sampling_params: Optional[Union[SamplingParams, + list[SamplingParams]]] = None, + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[LoRARequest] = None, + chat_template: Optional[str] = None, + chat_template_content_format: ChatTemplateContentFormatOption = "auto", + add_generation_prompt: bool = True, + continue_final_message: bool = False, + tools: Optional[list[dict[str, Any]]] = None, + chat_template_kwargs: Optional[dict[str, Any]] = None, + mm_processor_kwargs: Optional[dict[str, Any]] = None, + ) -> list[RequestOutput]: + """ + Generate responses for a chat conversation. + + The chat conversation is converted into a text prompt using the + tokenizer and calls the [generate][] method to generate the + responses. + + Multi-modal inputs can be passed in the same way you would pass them + to the OpenAI API. + + Args: + messages: A list of conversations or a single conversation. + + - Each conversation is represented as a list of messages. + - Each message is a dictionary with 'role' and 'content' keys. + + sampling_params: The sampling parameters for text generation. + If None, we use the default sampling parameters. When it + is a single value, it is applied to every prompt. When it + is a list, the list must have the same length as the + prompts and it is paired one by one with the prompt. + use_tqdm: If `True`, shows a tqdm progress bar. + If a callable (e.g., `functools.partial(tqdm, leave=False)`), + it is used to create the progress bar. + If `False`, no progress bar is created. + lora_request: LoRA request to use for generation, if any. + chat_template: The template to use for structuring the chat. + If not provided, the model's default chat template will be used. + chat_template_content_format: The format to render message content. + + - "string" will render the content as a string. + Example: `"Who are you?"` + - "openai" will render the content as a list of dictionaries, + similar to OpenAI schema. + Example: `[{"type": "text", "text": "Who are you?"}]` + + add_generation_prompt: If True, adds a generation template + to each message. + continue_final_message: If True, continues the final message in + the conversation instead of starting a new one. Cannot be + `True` if `add_generation_prompt` is also `True`. + chat_template_kwargs: Additional kwargs to pass to the chat + template. + mm_processor_kwargs: Multimodal processor kwarg overrides for this + chat request. Only used for offline requests. + + Returns: + A list of `RequestOutput` objects containing the generated + responses in the same order as the input messages. + """ + list_of_messages: list[list[ChatCompletionMessageParam]] + + # Handle multi and single conversations + if is_list_of(messages, list): + # messages is list[list[...]] + list_of_messages = cast(list[list[ChatCompletionMessageParam]], + messages) + else: + # messages is list[...] + list_of_messages = [ + cast(list[ChatCompletionMessageParam], messages) + ] + + tokenizer = self.get_tokenizer(lora_request) + model_config = self.llm_engine.get_model_config() + resolved_content_format = resolve_chat_template_content_format( + chat_template, + tools, + chat_template_content_format, + tokenizer, + model_config=model_config, + ) + + _chat_template_kwargs: dict[str, Any] = dict( + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + tools=tools, + ) + _chat_template_kwargs.update(chat_template_kwargs or {}) + + prompts: list[Union[TokensPrompt, TextPrompt]] = [] + + for msgs in list_of_messages: + # NOTE: _parse_chat_message_content_parts() currently doesn't + # handle mm_processor_kwargs, since there is no implementation in + # the chat message parsing for it. + conversation, mm_data = parse_chat_messages( + msgs, + model_config, + tokenizer, + content_format=resolved_content_format, + ) + + if isinstance(tokenizer, MistralTokenizer): + prompt_token_ids = apply_mistral_chat_template( + tokenizer, + messages=msgs, + **_chat_template_kwargs, + ) + else: + prompt_str = apply_hf_chat_template( + tokenizer=tokenizer, + conversation=conversation, + model_config=model_config, + **_chat_template_kwargs, + ) + # Special tokens are already included in chat templates so + # should not be added by the tokenizer in this case. + prompt_token_ids = tokenizer.encode(prompt_str, + add_special_tokens=False) + + prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) + + if mm_data is not None: + prompt["multi_modal_data"] = mm_data + + if mm_processor_kwargs is not None: + prompt["mm_processor_kwargs"] = mm_processor_kwargs + + prompts.append(prompt) + + return self.generate( + prompts, + sampling_params=sampling_params, + use_tqdm=use_tqdm, + lora_request=lora_request, + ) + + @overload + def encode( + self, + prompts: Union[PromptType, Sequence[PromptType]], + /, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + *, + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[PoolingRequestOutput]: + ... + + @overload # LEGACY: single (prompt + optional token ids) + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def encode( + self, + prompts: str, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[list[int]] = None, + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[PoolingRequestOutput]: + ... + + @overload # LEGACY: multi (prompt + optional token ids) + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def encode( + self, + prompts: list[str], + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[list[list[int]]] = None, + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[PoolingRequestOutput]: + ... + + @overload # LEGACY: single (token ids + optional prompt) + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def encode( + self, + prompts: Optional[str] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + *, + prompt_token_ids: list[int], + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[PoolingRequestOutput]: + ... + + @overload # LEGACY: multi (token ids + optional prompt) + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def encode( + self, + prompts: Optional[list[str]] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + *, + prompt_token_ids: list[list[int]], + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[PoolingRequestOutput]: + ... + + @overload # LEGACY: single or multi token ids [pos-only] + @deprecated("'prompt_token_ids' will become part of 'prompts'") + def encode( + self, + prompts: None, + pooling_params: None, + prompt_token_ids: Union[list[int], list[list[int]]], + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[PoolingRequestOutput]: + ... + + @deprecate_kwargs( + "prompt_token_ids", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'prompts' parameter instead.", + ) + def encode( + self, + prompts: Union[Union[PromptType, Sequence[PromptType]], + Optional[Union[str, list[str]]]] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[PoolingRequestOutput]: + """Apply pooling to the hidden states corresponding to the input + prompts. + + This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See [PromptType][vllm.inputs.PromptType] + for more details about the format of each prompts. + pooling_params: The pooling parameters for pooling. If None, we + use the default pooling parameters. + use_tqdm: If `True`, shows a tqdm progress bar. + If a callable (e.g., `functools.partial(tqdm, leave=False)`), + it is used to create the progress bar. + If `False`, no progress bar is created. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. + + Returns: + A list of `PoolingRequestOutput` objects containing the + pooled hidden states in the same order as the input prompts. + + Note: + Using `prompts` and `prompt_token_ids` as keyword parameters is + considered legacy and may be deprecated in the future. You should + instead pass them via the `inputs` parameter. + """ + runner_type = self.llm_engine.model_config.runner_type + if runner_type != "pooling": + messages = ["LLM.encode() is only supported for pooling models."] + + supported_runner_types = self.llm_engine.model_config \ + .supported_runner_types + if "pooling" in supported_runner_types: + messages.append( + "Your model supports the 'pooling' runner, but is " + f"currently initialized for the '{runner_type}' runner. " + "Please initialize vLLM using `--task embed`, " + "`--task classify`, `--task score` etc.") + + raise ValueError(" ".join(messages)) + + if prompt_token_ids is not None: + parsed_prompts = self._convert_v1_inputs( + prompts=cast(Optional[Union[str, list[str]]], prompts), + prompt_token_ids=prompt_token_ids, + ) + else: + parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], + prompts) + + if pooling_params is None: + # Use default pooling params. + pooling_params = PoolingParams() + elif isinstance(pooling_params, PoolingParams): + pooling_params.verify(self.llm_engine.model_config) + else: + for pooling_param in pooling_params: + pooling_param.verify(self.llm_engine.model_config) + + tokenization_kwargs: dict[str, Any] = {} + _validate_truncation_size(self.llm_engine.model_config.max_model_len, + truncate_prompt_tokens, tokenization_kwargs) + + self._validate_and_add_requests( + prompts=parsed_prompts, + params=pooling_params, + use_tqdm=use_tqdm, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + prompt_adapter_request=prompt_adapter_request, + ) + + outputs = self._run_engine(use_tqdm=use_tqdm) + return self.engine_class.validate_outputs(outputs, + PoolingRequestOutput) + + def embed( + self, + prompts: Union[PromptType, Sequence[PromptType]], + /, + *, + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[EmbeddingRequestOutput]: + """ + Generate an embedding vector for each prompt. + + This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See [PromptType][vllm.inputs.PromptType] + for more details about the format of each prompts. + pooling_params: The pooling parameters for pooling. If None, we + use the default pooling parameters. + use_tqdm: If `True`, shows a tqdm progress bar. + If a callable (e.g., `functools.partial(tqdm, leave=False)`), + it is used to create the progress bar. + If `False`, no progress bar is created. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. + + Returns: + A list of `EmbeddingRequestOutput` objects containing the + embedding vectors in the same order as the input prompts. + """ + if self.llm_engine.model_config.task != "embed": + raise ValueError( + "Embedding API is only enabled for `--task embed`") + + items = self.encode(prompts, + truncate_prompt_tokens=truncate_prompt_tokens, + use_tqdm=use_tqdm, + pooling_params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + return [EmbeddingRequestOutput.from_base(item) for item in items] + + def classify( + self, + prompts: Union[PromptType, Sequence[PromptType]], + /, + *, + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[ClassificationRequestOutput]: + """ + Generate class logits for each prompt. + + This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See [PromptType][vllm.inputs.PromptType] + for more details about the format of each prompts. + use_tqdm: If `True`, shows a tqdm progress bar. + If a callable (e.g., `functools.partial(tqdm, leave=False)`), + it is used to create the progress bar. + If `False`, no progress bar is created. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. + + Returns: + A list of `ClassificationRequestOutput` objects containing the + embedding vectors in the same order as the input prompts. + """ + if self.llm_engine.model_config.task != "classify": + raise ValueError( + "Classification API is only enabled for `--task classify`") + + items = self.encode(prompts, + use_tqdm=use_tqdm, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + return [ClassificationRequestOutput.from_base(item) for item in items] + + def _embedding_score( + self, + tokenizer: AnyTokenizer, + text_1: list[Union[str, TextPrompt, TokensPrompt]], + text_2: list[Union[str, TextPrompt, TokensPrompt]], + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[ScoringRequestOutput]: + + encoded_output: list[PoolingRequestOutput] = self.encode( + text_1 + text_2, + truncate_prompt_tokens=truncate_prompt_tokens, + use_tqdm=use_tqdm, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + encoded_output_1: list[PoolingRequestOutput] = encoded_output[ + 0:len(text_1)] + encoded_output_2: list[PoolingRequestOutput] = encoded_output[ + len(text_1):] + + if len(encoded_output_1) == 1: + encoded_output_1 = encoded_output_1 * len(encoded_output_2) + + scores = _cosine_similarity(tokenizer=tokenizer, + embed_1=encoded_output_1, + embed_2=encoded_output_2) + + items = self.engine_class.validate_outputs(scores, + PoolingRequestOutput) + return [ScoringRequestOutput.from_base(item) for item in items] + + def _cross_encoding_score( + self, + tokenizer: AnyTokenizer, + text_1: list[str], + text_2: list[str], + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[ScoringRequestOutput]: + + if isinstance(tokenizer, MistralTokenizer): + raise ValueError( + "Score API is only enabled for `--task embed or score`") + + if len(text_1) == 1: + text_1 = text_1 * len(text_2) + + input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)] + + pooling_params = PoolingParams(use_cross_encoder=True) + + tokenization_kwargs: dict[str, Any] = {} + _validate_truncation_size(self.llm_engine.model_config.max_model_len, + truncate_prompt_tokens, tokenization_kwargs) + + parsed_prompts = [] + + for q, t in input_pairs: + prompt_inputs = tokenizer(text=q, + text_pair=t, + **tokenization_kwargs) + engine_prompt = TokensPrompt( + prompt_token_ids=prompt_inputs["input_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + parsed_prompts.append(engine_prompt) + + self._validate_and_add_requests( + prompts=parsed_prompts, + params=pooling_params, + use_tqdm=use_tqdm, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + outputs = self._run_engine(use_tqdm=use_tqdm) + items = self.engine_class.validate_outputs(outputs, + PoolingRequestOutput) + + return [ScoringRequestOutput.from_base(item) for item in items] + + def score( + self, + text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]], + text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]], + /, + *, + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> list[ScoringRequestOutput]: + """Generate similarity scores for all pairs ``. + + The inputs can be `1 -> 1`, `1 -> N` or `N -> N`. + In the `1 - N` case the `text_1` sentence will be replicated `N` + times to pair with the `text_2` sentences. + The input pairs are used to build a list of prompts for the + cross encoder model. This class automatically batches the prompts, + considering the memory constraint. For the best performance, put all + of your texts into a single list and pass it to this method. + + Args: + text_1: can be a single prompt or a list of prompts, in which + case it has to have the same length as the `text_2` list + text_2: The texts to pair with the query to form the input + to the LLM. See [PromptType][vllm.inputs.PromptType] for + more details about the format of each prompts. + use_tqdm: If `True`, shows a tqdm progress bar. + If a callable (e.g., `functools.partial(tqdm, leave=False)`), + it is used to create the progress bar. + If `False`, no progress bar is created. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. + + Returns: + A list of `ScoringRequestOutput` objects containing the + generated scores in the same order as the input prompts. + """ + runner_type = self.llm_engine.model_config.runner_type + if runner_type != "pooling": + messages = ["LLM.score() is only supported for pooling models."] + + supported_runner_types = self.llm_engine.model_config \ + .supported_runner_types + if "pooling" in supported_runner_types: + messages.append( + "Your model supports the 'pooling' runner, but is " + f"currently initialized for the '{runner_type}' runner. " + "Please initialize vLLM using `--task embed`, " + "`--task classify`, `--task score` etc.") + + raise ValueError(" ".join(messages)) + + if self.llm_engine.model_config.task not in ("embed", "classify"): + raise ValueError("Score API is only enabled for " + "`--task embed or --task classify`.") + + if (self.llm_engine.model_config.task == "classify" + and self.llm_engine.model_config.hf_config.num_labels != 1): + raise ValueError("Score API is only enabled for num_labels == 1.") + + # the tokenizer for models such as + # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing + # lists of tokens to the `text` and `text_pair` kwargs + tokenizer = self.get_tokenizer() + + def ensure_str(prompt: SingletonPrompt): + if isinstance(prompt, dict): + if "multi_modal_data" in prompt: + raise ValueError("Multi-modal prompt is not " + "supported for scoring") + elif "prompt_token_ids" in prompt: + prompt = tokenizer.decode( + cast(TokensPrompt, prompt)["prompt_token_ids"]) + elif "prompt" in prompt: + prompt = cast(TextPrompt, prompt)["prompt"] + assert type(prompt) is str + return prompt + + if isinstance(text_1, (str, dict)): + # Convert a single prompt to a list. + text_1 = [text_1] + input_text_1: list[str] = [ensure_str(t) for t in text_1] + + if isinstance(text_2, (str, dict)): + # Convert a single prompt to a list. + text_2 = [text_2] + input_text_2: list[str] = [ensure_str(t) for t in text_2] + + _validate_score_input_lens(input_text_1, input_text_2) + + if self.llm_engine.model_config.is_cross_encoder: + return self._cross_encoding_score(tokenizer, input_text_1, + input_text_2, + truncate_prompt_tokens, use_tqdm, + lora_request, + prompt_adapter_request) + else: + return self._embedding_score( + tokenizer, + input_text_1, # type: ignore[arg-type] + input_text_2, # type: ignore[arg-type] + truncate_prompt_tokens, + use_tqdm, + lora_request, + prompt_adapter_request) + + def start_profile(self) -> None: + self.llm_engine.start_profile() + + def stop_profile(self) -> None: + self.llm_engine.stop_profile() + + def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: + return self.llm_engine.reset_prefix_cache(device) + + def sleep(self, level: int = 1): + """ + Put the engine to sleep. The engine should not process any requests. + The caller should guarantee that no requests are being processed + during the sleep period, before `wake_up` is called. + + Args: + level: The sleep level. Level 1 sleep will offload the model + weights and discard the kv cache. The content of kv cache + is forgotten. Level 1 sleep is good for sleeping and waking + up the engine to run the same model again. The model weights + are backed up in CPU memory. Please make sure there's enough + CPU memory to store the model weights. Level 2 sleep will + discard both the model weights and the kv cache. The content + of both the model weights and kv cache is forgotten. Level 2 + sleep is good for sleeping and waking up the engine to run a + different model or update the model, where previous model + weights are not needed. It reduces CPU memory pressure. + """ + self.reset_prefix_cache() + self.llm_engine.sleep(level=level) + + def wake_up(self, tags: Optional[list[str]] = None): + """ + Wake up the engine from sleep mode. See the [sleep][] method + for more details. + + Args: + tags: An optional list of tags to reallocate the engine memory + for specific memory allocations. Values must be in + `("weights", "kv_cache")`. If None, all memory is reallocated. + wake_up should be called with all tags (or None) before the + engine is used again. + """ + self.llm_engine.wake_up(tags) + + def get_metrics(self) -> list["Metric"]: + """Return a snapshot of aggregated metrics from Prometheus. + + Returns: + A ``MetricSnapshot`` instance capturing the current state + of all aggregated metrics from Prometheus. + + Note: + This method is only available with the V1 LLM engine. + """ + from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine + assert isinstance(self.llm_engine, V1LLMEngine) + return self.llm_engine.get_metrics() + + # LEGACY + def _convert_v1_inputs( + self, + prompts: Optional[Union[str, list[str]]], + prompt_token_ids: Optional[Union[list[int], list[list[int]]]], + ): + # skip_tokenizer_init is now checked in engine + + if prompts is None and prompt_token_ids is None: + raise ValueError( + "Either prompts or prompt_token_ids must be provided.") + if prompts is not None and prompt_token_ids is not None \ + and len(prompts) != len(prompt_token_ids): + raise ValueError( + "The lengths of prompts and prompt_token_ids must be the same." + ) + + if prompts is not None: + prompts = [p["content"] for p in parse_and_batch_prompt(prompts)] + if prompt_token_ids is not None: + prompt_token_ids = [ + p["content"] for p in parse_and_batch_prompt(prompt_token_ids) + ] + if prompts is not None: + num_requests = len(prompts) + elif prompt_token_ids is not None: + num_requests = len(prompt_token_ids) + parsed_prompts: list[PromptType] = [] + for i in range(num_requests): + item: PromptType + + if prompts is not None: + item = TextPrompt(prompt=prompts[i]) + elif prompt_token_ids is not None: + item = TokensPrompt(prompt_token_ids=prompt_token_ids[i]) + else: + raise AssertionError + + parsed_prompts.append(item) + + return parsed_prompts + + def _validate_and_add_requests( + self, + prompts: Union[PromptType, Sequence[PromptType]], + params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, + Sequence[PoolingParams]], + *, + use_tqdm: Union[bool, Callable[..., tqdm]] = True, + lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], + prompt_adapter_request: Optional[PromptAdapterRequest], + tokenization_kwargs: Optional[dict[str, Any]] = None, + guided_options: Optional[GuidedDecodingRequest] = None, + priority: Optional[list[int]] = None, + ) -> None: + if guided_options is not None: + warnings.warn( + "guided_options_request is deprecated, use " + "SamplingParams.guided_decoding instead", + DeprecationWarning, + stacklevel=2, + ) + + if isinstance(prompts, (str, dict)): + # Convert a single prompt to a list. + prompts = [prompts] + + num_requests = len(prompts) + if isinstance(params, Sequence) and len(params) != num_requests: + raise ValueError("The lengths of prompts and params " + "must be the same.") + if isinstance(lora_request, + Sequence) and len(lora_request) != num_requests: + raise ValueError("The lengths of prompts and lora_request " + "must be the same.") + + for sp in params if isinstance(params, Sequence) else (params, ): + if isinstance(sp, SamplingParams): + self._add_guided_params(sp, guided_options) + + # We only care about the final output + sp.output_kind = RequestOutputKind.FINAL_ONLY + + # Add requests to the engine. + it = prompts + if use_tqdm: + tqdm_func = use_tqdm if callable(use_tqdm) else tqdm + it = tqdm_func(it, desc="Adding requests") + + for i, prompt in enumerate(it): + self._add_request( + prompt, + params[i] if isinstance(params, Sequence) else params, + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request[i] if isinstance( + lora_request, Sequence) else lora_request, + prompt_adapter_request=prompt_adapter_request, + priority=priority[i] if priority else 0, + ) + + def _add_request( + self, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + request_id = str(next(self.request_counter)) + self.llm_engine.add_request( + request_id, + prompt, + params, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ) + + def _add_guided_params( + self, + params: SamplingParams, + guided_options: Optional[GuidedDecodingRequest] = None): + if guided_options is None: + return params + + if params.guided_decoding is not None: + raise ValueError("Cannot set both guided_options_request and " + "params.guided_decoding.") + + params.guided_decoding = GuidedDecodingParams( + json=guided_options.guided_json, + regex=guided_options.guided_regex, + choice=guided_options.guided_choice, + grammar=guided_options.guided_grammar, + json_object=guided_options.guided_json_object, + backend=guided_options.guided_decoding_backend, + whitespace_pattern=guided_options.guided_whitespace_pattern, + structural_tag=guided_options.structural_tag, + ) + return params + + def _run_engine( + self, + *, + use_tqdm: Union[bool, Callable[..., tqdm]] = True + ) -> list[Union[RequestOutput, PoolingRequestOutput]]: + # Initialize tqdm. + if use_tqdm: + num_requests = self.llm_engine.get_num_unfinished_requests() + tqdm_func = use_tqdm if callable(use_tqdm) else tqdm + pbar = tqdm_func( + total=num_requests, + desc="Processed prompts", + dynamic_ncols=True, + postfix=(f"est. speed input: {0:.2f} toks/s, " + f"output: {0:.2f} toks/s"), + ) + + # Run the engine. + outputs: list[Union[RequestOutput, PoolingRequestOutput]] = [] + total_in_toks = 0 + total_out_toks = 0 + while self.llm_engine.has_unfinished_requests(): + step_outputs = self.llm_engine.step() + for output in step_outputs: + if output.finished: + outputs.append(output) + if use_tqdm: + if isinstance(output, RequestOutput): + # Calculate tokens only for RequestOutput + n = len(output.outputs) + assert output.prompt_token_ids is not None + total_in_toks += len(output.prompt_token_ids) * n + in_spd = total_in_toks / pbar.format_dict["elapsed"] + total_out_toks += sum( + len(stp.token_ids) for stp in output.outputs) + out_spd = (total_out_toks / + pbar.format_dict["elapsed"]) + pbar.postfix = ( + f"est. speed input: {in_spd:.2f} toks/s, " + f"output: {out_spd:.2f} toks/s") + pbar.update(n) + else: + pbar.update(1) + if pbar.n == num_requests: + pbar.refresh() + + if use_tqdm: + pbar.close() + + # Sort the outputs by request ID. + # This is necessary because some requests may be finished earlier than + # its previous requests. + return sorted(outputs, key=lambda x: int(x.request_id)) diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py new file mode 100644 index 0000000..f3aee18 --- /dev/null +++ b/vllm/entrypoints/logger.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, Union + +import torch + +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import BeamSearchParams, SamplingParams + +logger = init_logger(__name__) + + +class RequestLogger: + + def __init__(self, *, max_log_len: Optional[int]) -> None: + super().__init__() + + self.max_log_len = max_log_len + + def log_inputs( + self, + request_id: str, + prompt: Optional[str], + prompt_token_ids: Optional[list[int]], + prompt_embeds: Optional[torch.Tensor], + params: Optional[Union[SamplingParams, PoolingParams, + BeamSearchParams]], + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> None: + max_log_len = self.max_log_len + if max_log_len is not None: + if prompt is not None: + prompt = prompt[:max_log_len] + + if prompt_token_ids is not None: + prompt_token_ids = prompt_token_ids[:max_log_len] + + logger.info( + "Received request %s: prompt: %r, " + "params: %s, prompt_token_ids: %s, " + "prompt_embeds shape: %s, " + "lora_request: %s, prompt_adapter_request: %s.", request_id, + prompt, params, prompt_token_ids, + prompt_embeds.shape if prompt_embeds is not None else None, + lora_request, prompt_adapter_request) diff --git a/vllm/entrypoints/openai/__init__.py b/vllm/entrypoints/openai/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py new file mode 100644 index 0000000..6c0a95e --- /dev/null +++ b/vllm/entrypoints/openai/api_server.py @@ -0,0 +1,1495 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import atexit +import gc +import importlib +import inspect +import json +import multiprocessing +import os +import signal +import socket +import tempfile +import uuid +from argparse import Namespace +from collections.abc import AsyncIterator, Awaitable +from contextlib import asynccontextmanager +from functools import partial +from http import HTTPStatus +from typing import Annotated, Any, Optional + +import prometheus_client +import regex as re +import uvloop +from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, Response, StreamingResponse +from prometheus_client import make_asgi_app +from prometheus_fastapi_instrumentator import Instrumentator +from starlette.concurrency import iterate_in_threadpool +from starlette.datastructures import URL, Headers, MutableHeaders, State +from starlette.routing import Mount +from starlette.types import ASGIApp, Message, Receive, Scope, Send +from typing_extensions import assert_never + +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.engine.multiprocessing.engine import run_mp_engine +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import (load_chat_template, + resolve_hf_chat_template, + resolve_mistral_chat_template) +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.cli_args import (log_non_default_args, + make_arg_parser, + validate_parsed_serve_args) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionResponse, + ClassificationRequest, + ClassificationResponse, + CompletionRequest, + CompletionResponse, + DetokenizeRequest, + DetokenizeResponse, + EmbeddingChatRequest, + EmbeddingCompletionRequest, + EmbeddingRequest, + EmbeddingResponse, ErrorResponse, + LoadLoRAAdapterRequest, + PoolingChatRequest, + PoolingCompletionRequest, + PoolingRequest, PoolingResponse, + RerankRequest, RerankResponse, + ScoreRequest, ScoreResponse, + TokenizeRequest, + TokenizeResponse, + TranscriptionRequest, + TranscriptionResponse, + TranslationRequest, + TranslationResponse, + UnloadLoRAAdapterRequest) +# yapf: enable +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_classification import ( + ServingClassification) +from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import (BaseModelPath, + OpenAIServingModels) +from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling +from vllm.entrypoints.openai.serving_score import ServingScores +from vllm.entrypoints.openai.serving_tokenization import ( + OpenAIServingTokenization) +from vllm.entrypoints.openai.serving_transcription import ( + OpenAIServingTranscription, OpenAIServingTranslation) +from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.entrypoints.utils import (cli_env_setup, load_aware_call, + with_cancellation) +from vllm.logger import init_logger +from vllm.reasoning import ReasoningParserManager +from vllm.transformers_utils.config import ( + maybe_register_config_serialize_by_value) +from vllm.transformers_utils.tokenizer import MistralTokenizer +from vllm.usage.usage_lib import UsageContext +from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path, + is_valid_ipv6_address, set_ulimit) +from vllm.v1.metrics.prometheus import get_prometheus_registry +from vllm.version import __version__ as VLLM_VERSION + +prometheus_multiproc_dir: tempfile.TemporaryDirectory + +# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) +logger = init_logger('vllm.entrypoints.openai.api_server') + +_running_tasks: set[asyncio.Task] = set() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + try: + if app.state.log_stats: + engine_client: EngineClient = app.state.engine_client + + async def _force_log(): + while True: + await asyncio.sleep(10.) + await engine_client.do_log_stats() + + task = asyncio.create_task(_force_log()) + _running_tasks.add(task) + task.add_done_callback(_running_tasks.remove) + else: + task = None + + # Mark the startup heap as static so that it's ignored by GC. + # Reduces pause times of oldest generation collections. + gc.collect() + gc.freeze() + try: + yield + finally: + if task is not None: + task.cancel() + finally: + # Ensure app state including engine ref is gc'd + del app.state + + +@asynccontextmanager +async def build_async_engine_client( + args: Namespace, + client_config: Optional[dict[str, Any]] = None, +) -> AsyncIterator[EngineClient]: + + # Context manager to handle engine_client lifecycle + # Ensures everything is shutdown and cleaned up on error/exit + engine_args = AsyncEngineArgs.from_cli_args(args) + + async with build_async_engine_client_from_engine_args( + engine_args, args.disable_frontend_multiprocessing, + client_config) as engine: + yield engine + + +@asynccontextmanager +async def build_async_engine_client_from_engine_args( + engine_args: AsyncEngineArgs, + disable_frontend_multiprocessing: bool = False, + client_config: Optional[dict[str, Any]] = None, +) -> AsyncIterator[EngineClient]: + """ + Create EngineClient, either: + - in-process using the AsyncLLMEngine Directly + - multiprocess using AsyncLLMEngine RPC + + Returns the Client or None if the creation failed. + """ + + # Create the EngineConfig (determines if we can use V1). + usage_context = UsageContext.OPENAI_API_SERVER + vllm_config = engine_args.create_engine_config(usage_context=usage_context) + + # V1 AsyncLLM. + if envs.VLLM_USE_V1: + if disable_frontend_multiprocessing: + logger.warning( + "V1 is enabled, but got --disable-frontend-multiprocessing. " + "To disable frontend multiprocessing, set VLLM_USE_V1=0.") + + from vllm.v1.engine.async_llm import AsyncLLM + async_llm: Optional[AsyncLLM] = None + client_index = client_config.pop( + "client_index") if client_config else 0 + try: + async_llm = AsyncLLM.from_vllm_config( + vllm_config=vllm_config, + usage_context=usage_context, + disable_log_requests=engine_args.disable_log_requests, + disable_log_stats=engine_args.disable_log_stats, + client_addresses=client_config, + client_index=client_index) + + # Don't keep the dummy data in memory + await async_llm.reset_mm_cache() + + yield async_llm + finally: + if async_llm: + async_llm.shutdown() + + # V0 AsyncLLM. + elif (MQLLMEngineClient.is_unsupported_config(vllm_config) + or disable_frontend_multiprocessing): + + engine_client: Optional[EngineClient] = None + try: + engine_client = AsyncLLMEngine.from_vllm_config( + vllm_config=vllm_config, + usage_context=usage_context, + disable_log_requests=engine_args.disable_log_requests, + disable_log_stats=engine_args.disable_log_stats) + yield engine_client + finally: + if engine_client and hasattr(engine_client, "shutdown"): + engine_client.shutdown() + + # V0MQLLMEngine. + else: + if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: + # Make TemporaryDirectory for prometheus multiprocessing + # Note: global TemporaryDirectory will be automatically + # cleaned up upon exit. + global prometheus_multiproc_dir + prometheus_multiproc_dir = tempfile.TemporaryDirectory() + os.environ[ + "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name + else: + logger.warning( + "Found PROMETHEUS_MULTIPROC_DIR was set by user. " + "This directory must be wiped between vLLM runs or " + "you will find inaccurate metrics. Unset the variable " + "and vLLM will properly handle cleanup.") + + # Select random path for IPC. + ipc_path = get_open_zmq_ipc_path() + logger.debug("Multiprocessing frontend to use %s for IPC Path.", + ipc_path) + + # Start RPCServer in separate process (holds the LLMEngine). + # the current process might have CUDA context, + # so we need to spawn a new process + context = multiprocessing.get_context("spawn") + + # Ensure we can serialize transformer config before spawning + maybe_register_config_serialize_by_value() + + # The Process can raise an exception during startup, which may + # not actually result in an exitcode being reported. As a result + # we use a shared variable to communicate the information. + engine_alive = multiprocessing.Value('b', True, lock=False) + engine_process = context.Process( + target=run_mp_engine, + args=(vllm_config, UsageContext.OPENAI_API_SERVER, ipc_path, + engine_args.disable_log_stats, + engine_args.disable_log_requests, engine_alive)) + engine_process.start() + engine_pid = engine_process.pid + assert engine_pid is not None, "Engine process failed to start." + logger.info("Started engine process with PID %d", engine_pid) + + def _cleanup_ipc_path(): + socket_path = ipc_path.replace("ipc://", "") + if os.path.exists(socket_path): + os.remove(socket_path) + + # Ensure we clean up the local IPC socket file on exit. + atexit.register(_cleanup_ipc_path) + + # Build RPCClient, which conforms to EngineClient Protocol. + build_client = partial(MQLLMEngineClient, ipc_path, vllm_config, + engine_pid) + mq_engine_client = await asyncio.get_running_loop().run_in_executor( + None, build_client) + try: + while True: + try: + await mq_engine_client.setup() + break + except TimeoutError: + if (not engine_process.is_alive() + or not engine_alive.value): + raise RuntimeError( + "Engine process failed to start. See stack " + "trace for the root cause.") from None + + yield mq_engine_client # type: ignore[misc] + finally: + # Ensure rpc server process was terminated + engine_process.terminate() + + # Close all open connections to the backend + mq_engine_client.close() + + # Wait for engine process to join + engine_process.join(4) + if engine_process.exitcode is None: + # Kill if taking longer than 5 seconds to stop + engine_process.kill() + + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from prometheus_client import multiprocess + multiprocess.mark_process_dead(engine_process.pid) + + +async def validate_json_request(raw_request: Request): + content_type = raw_request.headers.get("content-type", "").lower() + media_type = content_type.split(";", maxsplit=1)[0] + if media_type != "application/json": + raise RequestValidationError(errors=[ + "Unsupported Media Type: Only 'application/json' is allowed" + ]) + + +router = APIRouter() + + +class PrometheusResponse(Response): + media_type = prometheus_client.CONTENT_TYPE_LATEST + + +def mount_metrics(app: FastAPI): + """Mount prometheus metrics to a FastAPI app.""" + + registry = get_prometheus_registry() + + # `response_class=PrometheusResponse` is needed to return an HTTP response + # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8" + # instead of the default "application/json" which is incorrect. + # See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364 + Instrumentator( + excluded_handlers=[ + "/metrics", + "/health", + "/load", + "/ping", + "/version", + "/server_info", + ], + registry=registry, + ).add().instrument(app).expose(app, response_class=PrometheusResponse) + + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) + + # Workaround for 307 Redirect for /metrics + metrics_route.path_regex = re.compile("^/metrics(?P.*)$") + app.routes.append(metrics_route) + + +def base(request: Request) -> OpenAIServing: + # Reuse the existing instance + return tokenization(request) + + +def models(request: Request) -> OpenAIServingModels: + return request.app.state.openai_serving_models + + +def chat(request: Request) -> Optional[OpenAIServingChat]: + return request.app.state.openai_serving_chat + + +def completion(request: Request) -> Optional[OpenAIServingCompletion]: + return request.app.state.openai_serving_completion + + +def pooling(request: Request) -> Optional[OpenAIServingPooling]: + return request.app.state.openai_serving_pooling + + +def embedding(request: Request) -> Optional[OpenAIServingEmbedding]: + return request.app.state.openai_serving_embedding + + +def score(request: Request) -> Optional[ServingScores]: + return request.app.state.openai_serving_scores + + +def classify(request: Request) -> Optional[ServingClassification]: + return request.app.state.openai_serving_classification + + +def rerank(request: Request) -> Optional[ServingScores]: + return request.app.state.openai_serving_scores + + +def tokenization(request: Request) -> OpenAIServingTokenization: + return request.app.state.openai_serving_tokenization + + +def transcription(request: Request) -> OpenAIServingTranscription: + return request.app.state.openai_serving_transcription + + +def translation(request: Request) -> OpenAIServingTranslation: + return request.app.state.openai_serving_translation + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +@router.get("/health", response_class=Response) +async def health(raw_request: Request) -> Response: + """Health check.""" + await engine_client(raw_request).check_health() + return Response(status_code=200) + + +@router.get("/load") +async def get_server_load_metrics(request: Request): + # This endpoint returns the current server load metrics. + # It tracks requests utilizing the GPU from the following routes: + # - /v1/chat/completions + # - /v1/completions + # - /v1/audio/transcriptions + # - /v1/embeddings + # - /pooling + # - /classify + # - /score + # - /v1/score + # - /rerank + # - /v1/rerank + # - /v2/rerank + return JSONResponse( + content={'server_load': request.app.state.server_load_metrics}) + + +@router.get("/ping", response_class=Response) +@router.post("/ping", response_class=Response) +async def ping(raw_request: Request) -> Response: + """Ping check. Endpoint required for SageMaker""" + return await health(raw_request) + + +@router.post("/tokenize", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.NOT_FOUND.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + HTTPStatus.NOT_IMPLEMENTED.value: { + "model": ErrorResponse + }, + }) +@with_cancellation +async def tokenize(request: TokenizeRequest, raw_request: Request): + handler = tokenization(raw_request) + + try: + generator = await handler.create_tokenize(request, raw_request) + except NotImplementedError as e: + raise HTTPException(status_code=HTTPStatus.NOT_IMPLEMENTED.value, + detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, TokenizeResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/detokenize", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.NOT_FOUND.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) +@with_cancellation +async def detokenize(request: DetokenizeRequest, raw_request: Request): + handler = tokenization(raw_request) + + try: + generator = await handler.create_detokenize(request, raw_request) + except OverflowError as e: + raise RequestValidationError(errors=[str(e)]) from e + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, DetokenizeResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.get("/v1/models") +async def show_available_models(raw_request: Request): + handler = models(raw_request) + + models_ = await handler.show_available_models() + return JSONResponse(content=models_.model_dump()) + + +@router.get("/version") +async def show_version(): + ver = {"version": VLLM_VERSION} + return JSONResponse(content=ver) + + +@router.post("/v1/chat/completions", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: { + "content": { + "text/event-stream": {} + } + }, + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.NOT_FOUND.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + } + }) +@with_cancellation +@load_aware_call +async def create_chat_completion(request: ChatCompletionRequest, + raw_request: Request): + handler = chat(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Chat Completions API") + + generator = await handler.create_chat_completion(request, raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + + elif isinstance(generator, ChatCompletionResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +@router.post("/v1/completions", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: { + "content": { + "text/event-stream": {} + } + }, + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.NOT_FOUND.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) +@with_cancellation +@load_aware_call +async def create_completion(request: CompletionRequest, raw_request: Request): + handler = completion(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Completions API") + + try: + generator = await handler.create_completion(request, raw_request) + except OverflowError as e: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, + detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, CompletionResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +@router.post("/v1/embeddings", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) +@with_cancellation +@load_aware_call +async def create_embedding(request: EmbeddingRequest, raw_request: Request): + handler = embedding(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Embeddings API") + + generator = await handler.create_embedding(request, raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, EmbeddingResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/pooling", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) +@with_cancellation +@load_aware_call +async def create_pooling(request: PoolingRequest, raw_request: Request): + handler = pooling(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Pooling API") + + generator = await handler.create_pooling(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, PoolingResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/classify", dependencies=[Depends(validate_json_request)]) +@with_cancellation +@load_aware_call +async def create_classify(request: ClassificationRequest, + raw_request: Request): + handler = classify(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Classification API") + + generator = await handler.create_classify(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + + elif isinstance(generator, ClassificationResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/score", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) +@with_cancellation +@load_aware_call +async def create_score(request: ScoreRequest, raw_request: Request): + handler = score(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Score API") + + generator = await handler.create_score(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, ScoreResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/v1/score", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) +@with_cancellation +@load_aware_call +async def create_score_v1(request: ScoreRequest, raw_request: Request): + logger.warning( + "To indicate that Score API is not part of standard OpenAI API, we " + "have moved it to `/score`. Please update your client accordingly.") + + return await create_score(request, raw_request) + + +@router.post("/v1/audio/transcriptions", + responses={ + HTTPStatus.OK.value: { + "content": { + "text/event-stream": {} + } + }, + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.UNPROCESSABLE_ENTITY.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) +@with_cancellation +@load_aware_call +async def create_transcriptions(raw_request: Request, + request: Annotated[TranscriptionRequest, + Form()]): + handler = transcription(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Transcriptions API") + + audio_data = await request.file.read() + generator = await handler.create_transcription(audio_data, request, + raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + + elif isinstance(generator, TranscriptionResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +@router.post("/v1/audio/translations", + responses={ + HTTPStatus.OK.value: { + "content": { + "text/event-stream": {} + } + }, + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.UNPROCESSABLE_ENTITY.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) +@with_cancellation +@load_aware_call +async def create_translations(request: Annotated[TranslationRequest, + Form()], + raw_request: Request): + handler = translation(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Translations API") + + audio_data = await request.file.read() + generator = await handler.create_translation(audio_data, request, + raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + + elif isinstance(generator, TranslationResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +@router.post("/rerank", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) +@with_cancellation +@load_aware_call +async def do_rerank(request: RerankRequest, raw_request: Request): + handler = rerank(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Rerank (Score) API") + generator = await handler.do_rerank(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, RerankResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/v1/rerank", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) +@with_cancellation +async def do_rerank_v1(request: RerankRequest, raw_request: Request): + logger.warning_once( + "To indicate that the rerank API is not part of the standard OpenAI" + " API, we have located it at `/rerank`. Please update your client " + "accordingly. (Note: Conforms to JinaAI rerank API)") + + return await do_rerank(request, raw_request) + + +@router.post("/v2/rerank", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) +@with_cancellation +async def do_rerank_v2(request: RerankRequest, raw_request: Request): + return await do_rerank(request, raw_request) + + +TASK_HANDLERS: dict[str, dict[str, tuple]] = { + "generate": { + "messages": (ChatCompletionRequest, create_chat_completion), + "default": (CompletionRequest, create_completion), + }, + "embed": { + "messages": (EmbeddingChatRequest, create_embedding), + "default": (EmbeddingCompletionRequest, create_embedding), + }, + "score": { + "default": (RerankRequest, do_rerank) + }, + "rerank": { + "default": (RerankRequest, do_rerank) + }, + "reward": { + "messages": (PoolingChatRequest, create_pooling), + "default": (PoolingCompletionRequest, create_pooling), + }, + "classify": { + "messages": (PoolingChatRequest, create_pooling), + "default": (PoolingCompletionRequest, create_pooling), + }, +} + +if envs.VLLM_SERVER_DEV_MODE: + logger.warning("SECURITY WARNING: Development endpoints are enabled! " + "This should NOT be used in production!") + + @router.get("/server_info") + async def show_server_info(raw_request: Request): + server_info = {"vllm_config": str(raw_request.app.state.vllm_config)} + return JSONResponse(content=server_info) + + @router.post("/reset_prefix_cache") + async def reset_prefix_cache(raw_request: Request): + """ + Reset the prefix cache. Note that we currently do not check if the + prefix cache is successfully reset in the API server. + """ + device = None + device_str = raw_request.query_params.get("device") + if device_str is not None: + device = Device[device_str.upper()] + logger.info("Resetting prefix cache with specific %s...", str(device)) + await engine_client(raw_request).reset_prefix_cache(device) + return Response(status_code=200) + + @router.post("/sleep") + async def sleep(raw_request: Request): + # get POST params + level = raw_request.query_params.get("level", "1") + await engine_client(raw_request).sleep(int(level)) + # FIXME: in v0 with frontend multiprocessing, the sleep command + # is sent but does not finish yet when we return a response. + return Response(status_code=200) + + @router.post("/wake_up") + async def wake_up(raw_request: Request): + tags = raw_request.query_params.getlist("tags") + if tags == []: + # set to None to wake up all tags if no tags are provided + tags = None + logger.info("wake up the engine with tags: %s", tags) + await engine_client(raw_request).wake_up(tags) + # FIXME: in v0 with frontend multiprocessing, the wake-up command + # is sent but does not finish yet when we return a response. + return Response(status_code=200) + + @router.get("/is_sleeping") + async def is_sleeping(raw_request: Request): + logger.info("check whether the engine is sleeping") + is_sleeping = await engine_client(raw_request).is_sleeping() + return JSONResponse(content={"is_sleeping": is_sleeping}) + + +@router.post("/invocations", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) +async def invocations(raw_request: Request): + """ + For SageMaker, routes requests to other handlers based on model `task`. + """ + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}") from e + + task = raw_request.app.state.task + + if task not in TASK_HANDLERS: + raise HTTPException( + status_code=400, + detail=f"Unsupported task: '{task}' for '/invocations'. " + f"Expected one of {set(TASK_HANDLERS.keys())}") + + handler_config = TASK_HANDLERS[task] + if "messages" in body: + request_model, handler = handler_config["messages"] + else: + request_model, handler = handler_config["default"] + + # this is required since we lose the FastAPI automatic casting + request = request_model.model_validate(body) + return await handler(request, raw_request) + + +if envs.VLLM_TORCH_PROFILER_DIR: + logger.warning( + "Torch Profiler is enabled in the API server. This should ONLY be " + "used for local development!") + + @router.post("/start_profile") + async def start_profile(raw_request: Request): + logger.info("Starting profiler...") + await engine_client(raw_request).start_profile() + logger.info("Profiler started.") + return Response(status_code=200) + + @router.post("/stop_profile") + async def stop_profile(raw_request: Request): + logger.info("Stopping profiler...") + await engine_client(raw_request).stop_profile() + logger.info("Profiler stopped.") + return Response(status_code=200) + + +if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + logger.warning( + "LoRA dynamic loading & unloading is enabled in the API server. " + "This should ONLY be used for local development!") + + @router.post("/v1/load_lora_adapter", + dependencies=[Depends(validate_json_request)]) + async def load_lora_adapter(request: LoadLoRAAdapterRequest, + raw_request: Request): + handler = models(raw_request) + response = await handler.load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + return Response(status_code=200, content=response) + + @router.post("/v1/unload_lora_adapter", + dependencies=[Depends(validate_json_request)]) + async def unload_lora_adapter(request: UnloadLoRAAdapterRequest, + raw_request: Request): + handler = models(raw_request) + response = await handler.unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + return Response(status_code=200, content=response) + + +def load_log_config(log_config_file: Optional[str]) -> Optional[dict]: + if not log_config_file: + return None + try: + with open(log_config_file) as f: + return json.load(f) + except Exception as e: + logger.warning("Failed to load log config from file %s: error %s", + log_config_file, e) + return None + + +class AuthenticationMiddleware: + """ + Pure ASGI middleware that authenticates each request by checking + if the Authorization header exists and equals "Bearer {api_key}". + + Notes + ----- + There are two cases in which authentication is skipped: + 1. The HTTP method is OPTIONS. + 2. The request path doesn't start with /v1 (e.g. /health). + """ + + def __init__(self, app: ASGIApp, api_token: str) -> None: + self.app = app + self.api_token = api_token + + def __call__(self, scope: Scope, receive: Receive, + send: Send) -> Awaitable[None]: + if scope["type"] not in ("http", + "websocket") or scope["method"] == "OPTIONS": + # scope["type"] can be "lifespan" or "startup" for example, + # in which case we don't need to do anything + return self.app(scope, receive, send) + root_path = scope.get("root_path", "") + url_path = URL(scope=scope).path.removeprefix(root_path) + headers = Headers(scope=scope) + # Type narrow to satisfy mypy. + if url_path.startswith("/v1") and headers.get( + "Authorization") != f"Bearer {self.api_token}": + response = JSONResponse(content={"error": "Unauthorized"}, + status_code=401) + return response(scope, receive, send) + return self.app(scope, receive, send) + + +class XRequestIdMiddleware: + """ + Middleware the set's the X-Request-Id header for each response + to a random uuid4 (hex) value if the header isn't already + present in the request, otherwise use the provided request id. + """ + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + def __call__(self, scope: Scope, receive: Receive, + send: Send) -> Awaitable[None]: + if scope["type"] not in ("http", "websocket"): + return self.app(scope, receive, send) + + # Extract the request headers. + request_headers = Headers(scope=scope) + + async def send_with_request_id(message: Message) -> None: + """ + Custom send function to mutate the response headers + and append X-Request-Id to it. + """ + if message["type"] == "http.response.start": + response_headers = MutableHeaders(raw=message["headers"]) + request_id = request_headers.get("X-Request-Id", + uuid.uuid4().hex) + response_headers.append("X-Request-Id", request_id) + await send(message) + + return self.app(scope, receive, send_with_request_id) + + +def build_app(args: Namespace) -> FastAPI: + if args.disable_fastapi_docs: + app = FastAPI(openapi_url=None, + docs_url=None, + redoc_url=None, + lifespan=lifespan) + else: + app = FastAPI(lifespan=lifespan) + app.include_router(router) + app.root_path = args.root_path + + mount_metrics(app) + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + + @app.exception_handler(HTTPException) + async def http_exception_handler(_: Request, exc: HTTPException): + err = ErrorResponse(message=exc.detail, + type=HTTPStatus(exc.status_code).phrase, + code=exc.status_code) + return JSONResponse(err.model_dump(), status_code=exc.status_code) + + @app.exception_handler(RequestValidationError) + async def validation_exception_handler(_: Request, + exc: RequestValidationError): + exc_str = str(exc) + errors_str = str(exc.errors()) + + if exc.errors() and errors_str and errors_str != exc_str: + message = f"{exc_str} {errors_str}" + else: + message = exc_str + + err = ErrorResponse(message=message, + type=HTTPStatus.BAD_REQUEST.phrase, + code=HTTPStatus.BAD_REQUEST) + return JSONResponse(err.model_dump(), + status_code=HTTPStatus.BAD_REQUEST) + + # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY + if token := args.api_key or envs.VLLM_API_KEY: + app.add_middleware(AuthenticationMiddleware, api_token=token) + + if args.enable_request_id_headers: + app.add_middleware(XRequestIdMiddleware) + + if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE: + logger.warning("CAUTION: Enabling log response in the API Server. " + "This can include sensitive information and should be " + "avoided in production.") + + @app.middleware("http") + async def log_response(request: Request, call_next): + response = await call_next(request) + response_body = [ + section async for section in response.body_iterator + ] + response.body_iterator = iterate_in_threadpool(iter(response_body)) + logger.info("response_body={%s}", + response_body[0].decode() if response_body else None) + return response + + for middleware in args.middleware: + module_path, object_name = middleware.rsplit(".", 1) + imported = getattr(importlib.import_module(module_path), object_name) + if inspect.isclass(imported): + app.add_middleware(imported) # type: ignore[arg-type] + elif inspect.iscoroutinefunction(imported): + app.middleware("http")(imported) + else: + raise ValueError(f"Invalid middleware {middleware}. " + f"Must be a function or a class.") + + return app + + +async def init_app_state( + engine_client: EngineClient, + vllm_config: VllmConfig, + state: State, + args: Namespace, +) -> None: + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + + if args.disable_log_requests: + request_logger = None + else: + request_logger = RequestLogger(max_log_len=args.max_log_len) + + base_model_paths = [ + BaseModelPath(name=name, model_path=args.model) + for name in served_model_names + ] + + state.engine_client = engine_client + state.log_stats = not args.disable_log_stats + state.vllm_config = vllm_config + model_config = vllm_config.model_config + + resolved_chat_template = load_chat_template(args.chat_template) + if resolved_chat_template is not None: + # Get the tokenizer to check official template + tokenizer = await engine_client.get_tokenizer() + + if isinstance(tokenizer, MistralTokenizer): + # The warning is logged in resolve_mistral_chat_template. + resolved_chat_template = resolve_mistral_chat_template( + chat_template=resolved_chat_template) + else: + hf_chat_template = resolve_hf_chat_template( + tokenizer=tokenizer, + chat_template=None, + tools=None, + model_config=vllm_config.model_config, + ) + + if hf_chat_template != resolved_chat_template: + logger.warning( + "Using supplied chat template: %s\n" + "It is different from official chat template '%s'. " + "This discrepancy may lead to performance degradation.", + resolved_chat_template, args.model) + + state.openai_serving_models = OpenAIServingModels( + engine_client=engine_client, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + ) + await state.openai_serving_models.init_static_loras() + state.openai_serving_chat = OpenAIServingChat( + engine_client, + model_config, + state.openai_serving_models, + args.response_role, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + expand_tools_even_if_tool_choice_none=args. + expand_tools_even_if_tool_choice_none, + tool_parser=args.tool_call_parser, + reasoning_parser=args.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + ) if model_config.runner_type == "generate" else None + state.openai_serving_completion = OpenAIServingCompletion( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_force_include_usage=args.enable_force_include_usage, + ) if model_config.runner_type == "generate" else None + state.openai_serving_pooling = OpenAIServingPooling( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + ) if model_config.runner_type == "pooling" else None + state.openai_serving_embedding = OpenAIServingEmbedding( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + ) if model_config.task == "embed" else None + state.openai_serving_classification = ServingClassification( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + ) if model_config.task == "classify" else None + + enable_serving_reranking = (model_config.task == "classify" and getattr( + model_config.hf_config, "num_labels", 0) == 1) + state.jinaai_serving_reranking = ServingScores( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger) if enable_serving_reranking else None + state.openai_serving_scores = ServingScores( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger) if ( + model_config.task == "embed" or enable_serving_reranking) else None + + state.openai_serving_tokenization = OpenAIServingTokenization( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + ) + state.openai_serving_transcription = OpenAIServingTranscription( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + ) if model_config.runner_type == "transcription" else None + state.openai_serving_translation = OpenAIServingTranslation( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + ) if model_config.runner_type == "transcription" else None + state.task = model_config.task + + state.enable_server_load_tracking = args.enable_server_load_tracking + state.server_load_metrics = 0 + + +def create_server_socket(addr: tuple[str, int]) -> socket.socket: + family = socket.AF_INET + if is_valid_ipv6_address(addr[0]): + family = socket.AF_INET6 + + sock = socket.socket(family=family, type=socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + sock.bind(addr) + + return sock + + +def validate_api_server_args(args): + valid_tool_parses = ToolParserManager.tool_parsers.keys() + if args.enable_auto_tool_choice \ + and args.tool_call_parser not in valid_tool_parses: + raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " + f"(chose from {{ {','.join(valid_tool_parses)} }})") + + valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys() + if args.reasoning_parser \ + and args.reasoning_parser not in valid_reasoning_parses: + raise KeyError( + f"invalid reasoning parser: {args.reasoning_parser} " + f"(chose from {{ {','.join(valid_reasoning_parses)} }})") + + +def setup_server(args): + """Validate API server args, set up signal handler, create socket + ready to serve.""" + + logger.info("vLLM API server version %s", VLLM_VERSION) + log_non_default_args(args) + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + validate_api_server_args(args) + + # workaround to make sure that we bind the port before the engine is set up. + # This avoids race conditions with ray. + # see https://github.com/vllm-project/vllm/issues/8204 + sock_addr = (args.host or "", args.port) + sock = create_server_socket(sock_addr) + + # workaround to avoid footguns where uvicorn drops requests with too + # many concurrent requests active + set_ulimit() + + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + addr, port = sock_addr + is_ssl = args.ssl_keyfile and args.ssl_certfile + host_part = f"[{addr}]" if is_valid_ipv6_address( + addr) else addr or "0.0.0.0" + listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" + + return listen_address, sock + + +async def run_server(args, **uvicorn_kwargs) -> None: + """Run a single-worker API server.""" + listen_address, sock = setup_server(args) + await run_server_worker(listen_address, sock, args, **uvicorn_kwargs) + + +async def run_server_worker(listen_address, + sock, + args, + client_config=None, + **uvicorn_kwargs) -> None: + """Run a single API server worker.""" + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + server_index = client_config.get("client_index", 0) if client_config else 0 + + # Load logging config for uvicorn if specified + log_config = load_log_config(args.log_config_file) + if log_config is not None: + uvicorn_kwargs['log_config'] = log_config + + async with build_async_engine_client(args, client_config) as engine_client: + app = build_app(args) + + vllm_config = await engine_client.get_vllm_config() + await init_app_state(engine_client, vllm_config, app.state, args) + + logger.info("Starting vLLM API server %d on %s", server_index, + listen_address) + shutdown_task = await serve_http( + app, + sock=sock, + enable_ssl_refresh=args.enable_ssl_refresh, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + # NOTE: When the 'disable_uvicorn_access_log' value is True, + # no access log will be output. + access_log=not args.disable_uvicorn_access_log, + timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) + + # NB: Await server shutdown only after the backend context is exited + try: + await shutdown_task + finally: + sock.close() + + +if __name__ == "__main__": + # NOTE(simon): + # This section should be in sync with vllm/entrypoints/cli/main.py for CLI + # entrypoints. + cli_env_setup() + parser = FlexibleArgumentParser( + description="vLLM OpenAI-Compatible RESTful API server.") + parser = make_arg_parser(parser) + args = parser.parse_args() + validate_parsed_serve_args(args) + + uvloop.run(run_server(args)) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py new file mode 100644 index 0000000..4f8aaab --- /dev/null +++ b/vllm/entrypoints/openai/cli_args.py @@ -0,0 +1,331 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This file contains the command line arguments for the vLLM's +OpenAI-compatible server. It is kept in a separate file for documentation +purposes. +""" + +import argparse +import json +import ssl +from collections.abc import Sequence +from typing import Optional, Union, get_args + +import vllm.envs as envs +from vllm.engine.arg_utils import AsyncEngineArgs, optional_type +from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, + validate_chat_template) +from vllm.entrypoints.openai.serving_models import (LoRAModulePath, + PromptAdapterPath) +from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.logger import init_logger +from vllm.utils import FlexibleArgumentParser + +logger = init_logger(__name__) + + +class LoRAParserAction(argparse.Action): + + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Optional[Union[str, Sequence[str]]], + option_string: Optional[str] = None, + ): + if values is None: + values = [] + if isinstance(values, str): + raise TypeError("Expected values to be a list") + + lora_list: list[LoRAModulePath] = [] + for item in values: + if item in [None, '']: # Skip if item is None or empty string + continue + if '=' in item and ',' not in item: # Old format: name=path + name, path = item.split('=') + lora_list.append(LoRAModulePath(name, path)) + else: # Assume JSON format + try: + lora_dict = json.loads(item) + lora = LoRAModulePath(**lora_dict) + lora_list.append(lora) + except json.JSONDecodeError: + parser.error( + f"Invalid JSON format for --lora-modules: {item}") + except TypeError as e: + parser.error( + f"Invalid fields for --lora-modules: {item} - {str(e)}" + ) + setattr(namespace, self.dest, lora_list) + + +class PromptAdapterParserAction(argparse.Action): + + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Optional[Union[str, Sequence[str]]], + option_string: Optional[str] = None, + ): + if values is None: + values = [] + if isinstance(values, str): + raise TypeError("Expected values to be a list") + + adapter_list: list[PromptAdapterPath] = [] + for item in values: + name, path = item.split('=') + adapter_list.append(PromptAdapterPath(name, path)) + setattr(namespace, self.dest, adapter_list) + + +def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + parser.add_argument("--host", + type=optional_type(str), + default=None, + help="Host name.") + parser.add_argument("--port", type=int, default=8000, help="Port number.") + parser.add_argument( + "--uvicorn-log-level", + type=str, + default="info", + choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'], + help="Log level for uvicorn.") + parser.add_argument("--disable-uvicorn-access-log", + action="store_true", + help="Disable uvicorn access log.") + parser.add_argument("--allow-credentials", + action="store_true", + help="Allow credentials.") + parser.add_argument("--allowed-origins", + type=json.loads, + default=["*"], + help="Allowed origins.") + parser.add_argument("--allowed-methods", + type=json.loads, + default=["*"], + help="Allowed methods.") + parser.add_argument("--allowed-headers", + type=json.loads, + default=["*"], + help="Allowed headers.") + parser.add_argument("--api-key", + type=optional_type(str), + default=None, + help="If provided, the server will require this key " + "to be presented in the header.") + parser.add_argument( + "--lora-modules", + type=optional_type(str), + default=None, + nargs='+', + action=LoRAParserAction, + help="LoRA module configurations in either 'name=path' format" + "or JSON format. " + "Example (old format): ``'name=path'`` " + "Example (new format): " + "``{\"name\": \"name\", \"path\": \"lora_path\", " + "\"base_model_name\": \"id\"}``") + parser.add_argument( + "--prompt-adapters", + type=optional_type(str), + default=None, + nargs='+', + action=PromptAdapterParserAction, + help="Prompt adapter configurations in the format name=path. " + "Multiple adapters can be specified.") + parser.add_argument("--chat-template", + type=optional_type(str), + default=None, + help="The file path to the chat template, " + "or the template in single-line form " + "for the specified model.") + parser.add_argument( + '--chat-template-content-format', + type=str, + default="auto", + choices=get_args(ChatTemplateContentFormatOption), + help='The format to render message content within a chat template.' + '\n\n' + '* "string" will render the content as a string. ' + 'Example: ``"Hello World"``\n' + '* "openai" will render the content as a list of dictionaries, ' + 'similar to OpenAI schema. ' + 'Example: ``[{"type": "text", "text": "Hello world!"}]``') + parser.add_argument("--response-role", + type=optional_type(str), + default="assistant", + help="The role name to return if " + "``request.add_generation_prompt=true``.") + parser.add_argument("--ssl-keyfile", + type=optional_type(str), + default=None, + help="The file path to the SSL key file.") + parser.add_argument("--ssl-certfile", + type=optional_type(str), + default=None, + help="The file path to the SSL cert file.") + parser.add_argument("--ssl-ca-certs", + type=optional_type(str), + default=None, + help="The CA certificates file.") + parser.add_argument( + "--enable-ssl-refresh", + action="store_true", + default=False, + help="Refresh SSL Context when SSL certificate files change") + parser.add_argument( + "--ssl-cert-reqs", + type=int, + default=int(ssl.CERT_NONE), + help="Whether client certificate is required (see stdlib ssl module's)." + ) + parser.add_argument( + "--root-path", + type=optional_type(str), + default=None, + help="FastAPI root_path when app is behind a path based routing proxy." + ) + parser.add_argument( + "--middleware", + type=optional_type(str), + action="append", + default=[], + help="Additional ASGI middleware to apply to the app. " + "We accept multiple --middleware arguments. " + "The value should be an import path. " + "If a function is provided, vLLM will add it to the server " + "using ``@app.middleware('http')``. " + "If a class is provided, vLLM will add it to the server " + "using ``app.add_middleware()``. ") + parser.add_argument( + "--return-tokens-as-token-ids", + action="store_true", + help="When ``--max-logprobs`` is specified, represents single tokens " + " as strings of the form 'token_id:{token_id}' so that tokens " + "that are not JSON-encodable can be identified.") + parser.add_argument( + "--disable-frontend-multiprocessing", + action="store_true", + help="If specified, will run the OpenAI frontend server in the same " + "process as the model serving engine.") + parser.add_argument( + "--enable-request-id-headers", + action="store_true", + help="If specified, API server will add X-Request-Id header to " + "responses.") + parser.add_argument( + "--enable-auto-tool-choice", + action="store_true", + default=False, + help="Enable auto tool choice for supported models. Use " + "``--tool-call-parser`` to specify which parser to use.") + parser.add_argument( + "--expand-tools-even-if-tool-choice-none", + action="store_true", + default=False, + deprecated=True, + help="Include tool definitions in prompts " + "even when tool_choice='none'. " + "This is a transitional option that will be removed in v0.10.0. " + "In v0.10.0, tool definitions will always be included regardless of " + "tool_choice setting. Use this flag now to test the new behavior " + "before the breaking change.") + + valid_tool_parsers = ToolParserManager.tool_parsers.keys() + parser.add_argument( + "--tool-call-parser", + type=str, + metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in " + "--tool-parser-plugin", + default=None, + help= + "Select the tool call parser depending on the model that you're using." + " This is used to parse the model-generated tool call into OpenAI API " + "format. Required for ``--enable-auto-tool-choice``.") + + parser.add_argument( + "--tool-parser-plugin", + type=str, + default="", + help= + "Special the tool parser plugin write to parse the model-generated tool" + " into OpenAI API format, the name register in this plugin can be used " + "in ``--tool-call-parser``.") + + parser.add_argument( + "--log-config-file", + type=str, + default=envs.VLLM_LOGGING_CONFIG_PATH, + help="Path to logging config JSON file for both vllm and uvicorn", + ) + + parser = AsyncEngineArgs.add_cli_args(parser) + + parser.add_argument('--max-log-len', + type=int, + default=None, + help='Max number of prompt characters or prompt ' + 'ID numbers being printed in log.' + ' The default of None means unlimited.') + + parser.add_argument( + "--disable-fastapi-docs", + action='store_true', + default=False, + help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint." + ) + parser.add_argument( + "--enable-prompt-tokens-details", + action='store_true', + default=False, + help="If set to True, enable prompt_tokens_details in usage.") + parser.add_argument( + "--enable-force-include-usage", + action='store_true', + default=False, + help="If set to True, including usage on every request.") + parser.add_argument( + "--enable-server-load-tracking", + action='store_true', + default=False, + help= + "If set to True, enable tracking server_load_metrics in the app state." + ) + + return parser + + +def validate_parsed_serve_args(args: argparse.Namespace): + """Quick checks for model serve args that raise prior to loading.""" + if hasattr(args, "subparser") and args.subparser != "serve": + return + + # Ensure that the chat template is valid; raises if it likely isn't + validate_chat_template(args.chat_template) + + # Enable auto tool needs a tool call parser to be valid + if args.enable_auto_tool_choice and not args.tool_call_parser: + raise TypeError("Error: --enable-auto-tool-choice requires " + "--tool-call-parser") + if args.enable_prompt_embeds and args.enable_prompt_adapter: + raise ValueError( + "Cannot use prompt embeds and prompt adapter at the same time.") + + +def log_non_default_args(args: argparse.Namespace): + non_default_args = {} + parser = make_arg_parser(FlexibleArgumentParser()) + for arg, default in vars(parser.parse_args([])).items(): + if default != getattr(args, arg): + non_default_args[arg] = getattr(args, arg) + logger.info("non-default args: %s", non_default_args) + + +def create_parser_for_docs() -> FlexibleArgumentParser: + parser_for_docs = FlexibleArgumentParser( + prog="-m vllm.entrypoints.openai.api_server") + return make_arg_parser(parser_for_docs) diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py new file mode 100644 index 0000000..29d7225 --- /dev/null +++ b/vllm/entrypoints/openai/logits_processors.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from functools import lru_cache, partial +from typing import Optional, Union + +import torch + +from vllm.sampling_params import LogitsProcessor +from vllm.transformers_utils.tokenizer import AnyTokenizer + + +class AllowedTokenIdsLogitsProcessor: + """Logits processor for constraining generated tokens to a + specific set of token ids.""" + + def __init__(self, allowed_ids: Iterable[int]): + self.allowed_ids: Optional[list[int]] = list(allowed_ids) + self.mask: Optional[torch.Tensor] = None + + def __call__(self, token_ids: list[int], + logits: torch.Tensor) -> torch.Tensor: + if self.mask is None: + self.mask = torch.ones((logits.shape[-1], ), + dtype=torch.bool, + device=logits.device) + self.mask[self.allowed_ids] = False + self.allowed_ids = None + logits.masked_fill_(self.mask, float("-inf")) + return logits + + +@lru_cache(maxsize=32) +def _get_allowed_token_ids_logits_processor( + allowed_token_ids: frozenset[int], + vocab_size: int, +) -> LogitsProcessor: + if not allowed_token_ids: + raise ValueError("Empty allowed_token_ids provided") + if not all(0 <= tid < vocab_size for tid in allowed_token_ids): + raise ValueError("allowed_token_ids contains " + "out-of-vocab token id") + return AllowedTokenIdsLogitsProcessor(allowed_token_ids) + + +def logit_bias_logits_processor( + logit_bias: dict[int, float], + token_ids: list[int], + logits: torch.Tensor, +) -> torch.Tensor: + for token_id, bias in logit_bias.items(): + logits[token_id] += bias + return logits + + +def get_logits_processors( + logit_bias: Optional[Union[dict[int, float], dict[str, float]]], + allowed_token_ids: Optional[list[int]], + tokenizer: AnyTokenizer, +) -> list[LogitsProcessor]: + logits_processors: list[LogitsProcessor] = [] + if logit_bias: + try: + # Convert token_id to integer + # Clamp the bias between -100 and 100 per OpenAI API spec + clamped_logit_bias: dict[int, float] = { + int(token_id): min(100.0, max(-100.0, bias)) + for token_id, bias in logit_bias.items() + } + except ValueError as exc: + raise ValueError( + "Found token_id in logit_bias that is not " + "an integer or string representing an integer") from exc + + # Check if token_id is within the vocab size + for token_id, bias in clamped_logit_bias.items(): + if token_id < 0 or token_id >= len(tokenizer): + raise ValueError(f"token_id {token_id} in logit_bias contains " + "out-of-vocab token id") + + logits_processors.append( + partial(logit_bias_logits_processor, clamped_logit_bias)) + + if allowed_token_ids is not None: + logits_processors.append( + _get_allowed_token_ids_logits_processor( + frozenset(allowed_token_ids), len(tokenizer))) + + return logits_processors diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py new file mode 100644 index 0000000..d4db238 --- /dev/null +++ b/vllm/entrypoints/openai/protocol.py @@ -0,0 +1,2096 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py +import json +import time +from http import HTTPStatus +from typing import Annotated, Any, ClassVar, Literal, Optional, Union + +import regex as re +import torch +from fastapi import HTTPException, UploadFile +from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, + ValidationInfo, field_validator, model_validator) +from typing_extensions import TypeAlias + +from vllm import envs +from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, + random_tool_call_id) +from vllm.logger import init_logger +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, + RequestOutputKind, SamplingParams) +from vllm.sequence import Logprob +from vllm.utils import random_uuid, resolve_obj_by_qualname + +logger = init_logger(__name__) + +_LONG_INFO = torch.iinfo(torch.long) + + +class OpenAIBaseModel(BaseModel): + # OpenAI API does allow extra fields + model_config = ConfigDict(extra="allow") + + # Cache class field names + field_names: ClassVar[Optional[set[str]]] = None + + @model_validator(mode="wrap") + @classmethod + def __log_extra_fields__(cls, data, handler): + result = handler(data) + if not isinstance(data, dict): + return result + field_names = cls.field_names + if field_names is None: + # Get all class field names and their potential aliases + field_names = set() + for field_name, field in cls.model_fields.items(): + field_names.add(field_name) + if alias := getattr(field, "alias", None): + field_names.add(alias) + cls.field_names = field_names + + # Compare against both field names and aliases + if any(k not in field_names for k in data): + logger.warning( + "The following fields were present in the request " + "but ignored: %s", + data.keys() - field_names, + ) + return result + + +class ErrorResponse(OpenAIBaseModel): + object: str = "error" + message: str + type: str + param: Optional[str] = None + code: int + + +class ModelPermission(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") + object: str = "model_permission" + created: int = Field(default_factory=lambda: int(time.time())) + allow_create_engine: bool = False + allow_sampling: bool = True + allow_logprobs: bool = True + allow_search_indices: bool = False + allow_view: bool = True + allow_fine_tuning: bool = False + organization: str = "*" + group: Optional[str] = None + is_blocking: bool = False + + +class ModelCard(OpenAIBaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "vllm" + root: Optional[str] = None + parent: Optional[str] = None + max_model_len: Optional[int] = None + permission: list[ModelPermission] = Field(default_factory=list) + + +class ModelList(OpenAIBaseModel): + object: str = "list" + data: list[ModelCard] = Field(default_factory=list) + + +class PromptTokenUsageInfo(OpenAIBaseModel): + cached_tokens: Optional[int] = None + + +class UsageInfo(OpenAIBaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + prompt_tokens_details: Optional[PromptTokenUsageInfo] = None + + +class RequestResponseMetadata(BaseModel): + request_id: str + final_usage_info: Optional[UsageInfo] = None + + +class JsonSchemaResponseFormat(OpenAIBaseModel): + name: str + description: Optional[str] = None + # schema is the field in openai but that causes conflicts with pydantic so + # instead use json_schema with an alias + json_schema: Optional[dict[str, Any]] = Field(default=None, alias='schema') + strict: Optional[bool] = None + + +class StructuralTag(OpenAIBaseModel): + begin: str + # schema is the field, but that causes conflicts with pydantic so + # instead use structural_tag_schema with an alias + structural_tag_schema: Optional[dict[str, Any]] = Field(default=None, + alias="schema") + end: str + + +class StructuralTagResponseFormat(OpenAIBaseModel): + type: Literal["structural_tag"] + structures: list[StructuralTag] + triggers: list[str] + + +class ResponseFormat(OpenAIBaseModel): + # type must be "json_schema", "json_object", or "text" + type: Literal["text", "json_object", "json_schema"] + json_schema: Optional[JsonSchemaResponseFormat] = None + + +AnyResponseFormat = Union[ResponseFormat, StructuralTagResponseFormat] + + +class StreamOptions(OpenAIBaseModel): + include_usage: Optional[bool] = True + continuous_usage_stats: Optional[bool] = False + + +class FunctionDefinition(OpenAIBaseModel): + name: str + description: Optional[str] = None + parameters: Optional[dict[str, Any]] = None + + +class ChatCompletionToolsParam(OpenAIBaseModel): + type: Literal["function"] = "function" + function: FunctionDefinition + + +class ChatCompletionNamedFunction(OpenAIBaseModel): + name: str + + +class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): + function: ChatCompletionNamedFunction + type: Literal["function"] = "function" + + +# extra="forbid" is a workaround to have kwargs as a field, +# see https://github.com/pydantic/pydantic/issues/3125 +class LogitsProcessorConstructor(BaseModel): + qualname: str + args: Optional[list[Any]] = None + kwargs: Optional[dict[str, Any]] = None + + model_config = ConfigDict(extra="forbid") + + +LogitsProcessors = list[Union[str, LogitsProcessorConstructor]] + + +def get_logits_processors(processors: Optional[LogitsProcessors], + pattern: Optional[str]) -> Optional[list[Any]]: + if processors and pattern: + logits_processors = [] + for processor in processors: + qualname = processor if isinstance(processor, + str) else processor.qualname + if not re.match(pattern, qualname): + raise ValueError( + f"Logits processor '{qualname}' is not allowed by this " + "server. See --logits-processor-pattern engine argument " + "for more information.") + try: + logits_processor = resolve_obj_by_qualname(qualname) + except Exception as e: + raise ValueError( + f"Logits processor '{qualname}' could not be resolved: {e}" + ) from e + if isinstance(processor, LogitsProcessorConstructor): + logits_processor = logits_processor(*processor.args or [], + **processor.kwargs or {}) + logits_processors.append(logits_processor) + return logits_processors + elif processors: + raise ValueError( + "The `logits_processors` argument is not supported by this " + "server. See --logits-processor-pattern engine argugment " + "for more information.") + return None + + +class ChatCompletionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/chat/create + messages: list[ChatCompletionMessageParam] + model: Optional[str] = None + frequency_penalty: Optional[float] = 0.0 + logit_bias: Optional[dict[str, float]] = None + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = 0 + max_tokens: Optional[int] = Field( + default=None, + deprecated= + 'max_tokens is deprecated in favor of the max_completion_tokens field') + max_completion_tokens: Optional[int] = None + n: Optional[int] = 1 + presence_penalty: Optional[float] = 0.0 + response_format: Optional[AnyResponseFormat] = None + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + stop: Optional[Union[str, list[str]]] = [] + stream: Optional[bool] = False + stream_options: Optional[StreamOptions] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + tools: Optional[list[ChatCompletionToolsParam]] = None + tool_choice: Optional[Union[ + Literal["none"], + Literal["auto"], + Literal["required"], + ChatCompletionNamedToolChoiceParam, + ]] = "none" + + # NOTE this will be ignored by vLLM -- the model determines the behavior + parallel_tool_calls: Optional[bool] = False + user: Optional[str] = None + + # --8<-- [start:chat-completion-sampling-params] + best_of: Optional[int] = None + use_beam_search: bool = False + top_k: Optional[int] = None + min_p: Optional[float] = None + repetition_penalty: Optional[float] = None + length_penalty: float = 1.0 + stop_token_ids: Optional[list[int]] = [] + include_stop_str_in_output: bool = False + ignore_eos: bool = False + min_tokens: int = 0 + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + prompt_logprobs: Optional[int] = None + allowed_token_ids: Optional[list[int]] = None + bad_words: list[str] = Field(default_factory=list) + # --8<-- [end:chat-completion-sampling-params] + + # --8<-- [start:chat-completion-extra-params] + echo: bool = Field( + default=False, + description=( + "If true, the new message will be prepended with the last message " + "if they belong to the same role."), + ) + add_generation_prompt: bool = Field( + default=True, + description= + ("If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model."), + ) + continue_final_message: bool = Field( + default=False, + description= + ("If this is set, the chat will be formatted so that the final " + "message in the chat is open-ended, without any EOS tokens. The " + "model will continue this message rather than starting a new one. " + "This allows you to \"prefill\" part of the model's response for it. " + "Cannot be used at the same time as `add_generation_prompt`."), + ) + add_special_tokens: bool = Field( + default=False, + description=( + "If true, special tokens (e.g. BOS) will be added to the prompt " + "on top of what is added by the chat template. " + "For most models, the chat template takes care of adding the " + "special tokens so this should be set to false (as is the " + "default)."), + ) + documents: Optional[list[dict[str, str]]] = Field( + default=None, + description= + ("A list of dicts representing documents that will be accessible to " + "the model if it is performing RAG (retrieval-augmented generation)." + " If the template does not support RAG, this argument will have no " + "effect. We recommend that each document should be a dict containing " + "\"title\" and \"text\" keys."), + ) + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one."), + ) + chat_template_kwargs: Optional[dict[str, Any]] = Field( + default=None, + description=( + "Additional keyword args to pass to the template renderer. " + "Will be accessible by the chat template."), + ) + mm_processor_kwargs: Optional[dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the HF processor."), + ) + guided_json: Optional[Union[str, dict, BaseModel]] = Field( + default=None, + description=("If specified, the output will follow the JSON schema."), + ) + guided_regex: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the regex pattern."), + ) + guided_choice: Optional[list[str]] = Field( + default=None, + description=( + "If specified, the output will be exactly one of the choices."), + ) + guided_grammar: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the context free grammar."), + ) + structural_tag: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the structural tag schema."), + ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'"), + ) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding."), + ) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling."), + ) + request_id: str = Field( + default_factory=lambda: f"{random_uuid()}", + description=( + "The request_id related to this request. If the caller does " + "not set it, a random_uuid will be generated. This id is used " + "through out the inference process and return in response."), + ) + logits_processors: Optional[LogitsProcessors] = Field( + default=None, + description=( + "A list of either qualified names of logits processors, or " + "constructor objects, to apply when sampling. A constructor is " + "a JSON object with a required 'qualname' field specifying the " + "qualified name of the processor class/factory, and optional " + "'args' and 'kwargs' fields containing positional and keyword " + "arguments. For example: {'qualname': " + "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " + "{'param': 'value'}}.")) + return_tokens_as_token_ids: Optional[bool] = Field( + default=None, + description=( + "If specified with 'logprobs', tokens are represented " + " as strings of the form 'token_id:{token_id}' so that tokens " + "that are not JSON-encodable can be identified.")) + cache_salt: Optional[str] = Field( + default=None, + description=( + "If specified, the prefix cache will be salted with the provided " + "string to prevent an attacker to guess prompts in multi-user " + "environments. The salt should be random, protected from " + "access by 3rd parties, and long enough to be " + "unpredictable (e.g., 43 characters base64-encoded, corresponding " + "to 256 bit). Not supported by vLLM engine V0.")) + kv_transfer_params: Optional[dict[str, Any]] = Field( + default=None, + description="KVTransfer parameters used for disaggregated serving.") + + vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field( + default=None, + description=("Additional request parameters with string or " + "numeric values, used by custom extensions."), + ) + + # --8<-- [end:chat-completion-extra-params] + + # Default sampling parameters for chat completion requests + _DEFAULT_SAMPLING_PARAMS: dict = { + "repetition_penalty": 1.0, + "temperature": 1.0, + "top_p": 1.0, + "top_k": 0, + "min_p": 0.0, + } + + def to_beam_search_params( + self, max_tokens: int, + default_sampling_params: dict) -> BeamSearchParams: + + n = self.n if self.n is not None else 1 + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + + return BeamSearchParams( + beam_width=n, + max_tokens=max_tokens, + ignore_eos=self.ignore_eos, + temperature=temperature, + length_penalty=self.length_penalty, + include_stop_str_in_output=self.include_stop_str_in_output, + ) + + def to_sampling_params( + self, + max_tokens: int, + logits_processor_pattern: Optional[str], + default_sampling_params: dict, + ) -> SamplingParams: + + # Default parameters + if (repetition_penalty := self.repetition_penalty) is None: + repetition_penalty = default_sampling_params.get( + "repetition_penalty", + self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"], + ) + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + if (top_p := self.top_p) is None: + top_p = default_sampling_params.get( + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + if (top_k := self.top_k) is None: + top_k = default_sampling_params.get( + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) + if (min_p := self.min_p) is None: + min_p = default_sampling_params.get( + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) + + prompt_logprobs = self.prompt_logprobs + if prompt_logprobs is None and self.echo: + prompt_logprobs = self.top_logprobs + + guided_json_object = None + if self.response_format is not None: + if self.response_format.type == "json_object": + guided_json_object = True + elif self.response_format.type == "json_schema": + json_schema = self.response_format.json_schema + assert json_schema is not None + self.guided_json = json_schema.json_schema + elif self.response_format.type == "structural_tag": + structural_tag = self.response_format + assert structural_tag is not None and isinstance( + structural_tag, StructuralTagResponseFormat) + s_tag_obj = structural_tag.model_dump(by_alias=True) + self.structural_tag = json.dumps(s_tag_obj) + + guided_decoding = GuidedDecodingParams.from_optional( + json=self._get_guided_json_from_tool() or self.guided_json, + regex=self.guided_regex, + choice=self.guided_choice, + grammar=self.guided_grammar, + json_object=guided_json_object, + backend=self.guided_decoding_backend, + whitespace_pattern=self.guided_whitespace_pattern, + structural_tag=self.structural_tag, + ) + + extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {} + if self.kv_transfer_params: + # Pass in kv_transfer_params via extra_args + extra_args["kv_transfer_params"] = self.kv_transfer_params + return SamplingParams.from_optional( + n=self.n, + best_of=self.best_of, + presence_penalty=self.presence_penalty, + frequency_penalty=self.frequency_penalty, + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + seed=self.seed, + stop=self.stop, + stop_token_ids=self.stop_token_ids, + logprobs=self.top_logprobs if self.logprobs else None, + prompt_logprobs=prompt_logprobs, + ignore_eos=self.ignore_eos, + max_tokens=max_tokens, + min_tokens=self.min_tokens, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + logits_processors=get_logits_processors(self.logits_processors, + logits_processor_pattern), + include_stop_str_in_output=self.include_stop_str_in_output, + truncate_prompt_tokens=self.truncate_prompt_tokens, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, + guided_decoding=guided_decoding, + logit_bias=self.logit_bias, + bad_words= self.bad_words, + allowed_token_ids=self.allowed_token_ids, + extra_args=extra_args or None, + ) + + def _get_guided_json_from_tool( + self) -> Optional[Union[str, dict, BaseModel]]: + # user has chosen to not use any tool + if self.tool_choice == "none" or self.tools is None: + return None + + # user has chosen to use a named tool + if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam: + tool_name = self.tool_choice.function.name + tools = {tool.function.name: tool.function for tool in self.tools} + if tool_name not in tools: + raise ValueError( + f"Tool '{tool_name}' has not been passed in `tools`.") + tool = tools[tool_name] + return tool.parameters + + if self.tool_choice == "required": + # Pydantic schema generation cannot be used since the JSON schema + # has to be constructed for a specific instantiation of a tool list + # so that parameters of a function are correctly generated + # based on the chosen function name + def get_tool_schema(tool: ChatCompletionToolsParam) -> dict: + return { + "properties": { + "name": { + "type": "string", + "enum": [tool.function.name] + }, + # parameters are always generated as '{}' in the final + # output if they are missing from the request + # (i.e. are None or '{}') so the schema is + # updated to produce an empty object in that case + "parameters": tool.function.parameters + if tool.function.parameters else { + "type": "object", + "properties": {} + } + }, + "required": ["name", "parameters"] + } + + json_schema = { + "type": "array", + "minItems": 1, + "items": { + "type": "object", + "anyOf": [get_tool_schema(tool) for tool in self.tools] + } + } + return json_schema + + return None + + @model_validator(mode="before") + @classmethod + def validate_stream_options(cls, data): + if data.get("stream_options") and not data.get("stream"): + raise ValueError( + "Stream options can only be defined when `stream=True`.") + + return data + + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if (prompt_logprobs := data.get("prompt_logprobs")) is not None: + if data.get("stream") and prompt_logprobs > 0: + raise ValueError( + "`prompt_logprobs` are not available when `stream=True`.") + + if prompt_logprobs < 0: + raise ValueError("`prompt_logprobs` must be a positive value.") + + if (top_logprobs := data.get("top_logprobs")) is not None: + if top_logprobs < 0: + raise ValueError("`top_logprobs` must be a positive value.") + + if top_logprobs > 0 and not data.get("logprobs"): + raise ValueError( + "when using `top_logprobs`, `logprobs` must be set to true." + ) + + return data + + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + if isinstance(data, ValueError): + raise data + + guide_count = sum([ + "guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + # you can only use one kind of guided decoding + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice').") + # you can only either use guided decoding or tools, not both + if guide_count > 1 and data.get("tool_choice", "none") not in ( + "none", + "auto", + "required", + ): + raise ValueError( + "You can only either use guided decoding or tools, not both.") + return data + + @model_validator(mode="before") + @classmethod + def check_tool_usage(cls, data): + + # if "tool_choice" is not specified but tools are provided, + # default to "auto" tool_choice + if "tool_choice" not in data and data.get("tools"): + data["tool_choice"] = "auto" + + # if "tool_choice" is "none" -- no validation is needed for tools + if "tool_choice" in data and data["tool_choice"] == "none": + return data + + # if "tool_choice" is specified -- validation + if "tool_choice" in data: + + # ensure that if "tool choice" is specified, tools are present + if "tools" not in data or data["tools"] is None: + raise ValueError( + "When using `tool_choice`, `tools` must be set.") + + # make sure that tool choice is either a named tool + # OR that it's set to "auto" or "required" + if data["tool_choice"] not in [ + "auto", "required" + ] and not isinstance(data["tool_choice"], dict): + raise NotImplementedError( + f'Invalid value for `tool_choice`: {data["tool_choice"]}! '\ + 'Only named tools, "none", "auto" or "required" '\ + 'are supported.' + ) + + # ensure that if "tool_choice" is specified as an object, + # it matches a valid tool + correct_usage_message = 'Correct usage: `{"type": "function",' \ + ' "function": {"name": "my_function"}}`' + if isinstance(data["tool_choice"], dict): + valid_tool = False + function = data["tool_choice"].get("function") + if not isinstance(function, dict): + raise ValueError( + f"Invalid value for `function`: `{function}` in " + f"`tool_choice`! {correct_usage_message}") + if "name" not in function: + raise ValueError(f"Expected field `name` in `function` in " + f"`tool_choice`! {correct_usage_message}") + function_name = function["name"] + if not isinstance(function_name, + str) or len(function_name) == 0: + raise ValueError( + f"Invalid `name` in `function`: `{function_name}`" + f" in `tool_choice`! {correct_usage_message}") + for tool in data["tools"]: + if tool["function"]["name"] == function_name: + valid_tool = True + break + if not valid_tool: + raise ValueError( + "The tool specified in `tool_choice` does not match any" + " of the specified `tools`") + return data + + @model_validator(mode="before") + @classmethod + def check_generation_prompt(cls, data): + if data.get("continue_final_message") and data.get( + "add_generation_prompt"): + raise ValueError("Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True.") + return data + + @model_validator(mode="before") + @classmethod + def check_cache_salt_support(cls, data): + if data.get("cache_salt") is not None: + if not envs.VLLM_USE_V1: + raise ValueError( + "Parameter 'cache_salt' is not supported with " + "this instance of vLLM, which uses engine V0.") + if not isinstance(data["cache_salt"], + str) or not data["cache_salt"]: + raise ValueError("Parameter 'cache_salt' must be a " + "non-empty string if provided.") + return data + + +class CompletionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/completions/create + model: Optional[str] = None + prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None + prompt_embeds: Optional[Union[bytes, list[bytes]]] = None + best_of: Optional[int] = None + echo: Optional[bool] = False + frequency_penalty: Optional[float] = 0.0 + logit_bias: Optional[dict[str, float]] = None + logprobs: Optional[int] = None + max_tokens: Optional[int] = 16 + n: int = 1 + presence_penalty: Optional[float] = 0.0 + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + stop: Optional[Union[str, list[str]]] = [] + stream: Optional[bool] = False + stream_options: Optional[StreamOptions] = None + suffix: Optional[str] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + user: Optional[str] = None + + # --8<-- [start:completion-sampling-params] + use_beam_search: bool = False + top_k: Optional[int] = None + min_p: Optional[float] = None + repetition_penalty: Optional[float] = None + length_penalty: float = 1.0 + stop_token_ids: Optional[list[int]] = [] + include_stop_str_in_output: bool = False + ignore_eos: bool = False + min_tokens: int = 0 + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + allowed_token_ids: Optional[list[int]] = None + prompt_logprobs: Optional[int] = None + # --8<-- [end:completion-sampling-params] + + # --8<-- [start:completion-extra-params] + add_special_tokens: bool = Field( + default=True, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " + "the prompt."), + ) + response_format: Optional[AnyResponseFormat] = Field( + default=None, + description=( + "Similar to chat completion, this parameter specifies the format " + "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}" + ", {'type': 'structural_tag'}, or {'type': 'text' } is supported." + ), + ) + guided_json: Optional[Union[str, dict, BaseModel]] = Field( + default=None, + description="If specified, the output will follow the JSON schema.", + ) + guided_regex: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the regex pattern."), + ) + guided_choice: Optional[list[str]] = Field( + default=None, + description=( + "If specified, the output will be exactly one of the choices."), + ) + guided_grammar: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the context free grammar."), + ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be one of " + "'outlines' / 'lm-format-enforcer'"), + ) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding."), + ) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling."), + ) + logits_processors: Optional[LogitsProcessors] = Field( + default=None, + description=( + "A list of either qualified names of logits processors, or " + "constructor objects, to apply when sampling. A constructor is " + "a JSON object with a required 'qualname' field specifying the " + "qualified name of the processor class/factory, and optional " + "'args' and 'kwargs' fields containing positional and keyword " + "arguments. For example: {'qualname': " + "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " + "{'param': 'value'}}.")) + + return_tokens_as_token_ids: Optional[bool] = Field( + default=None, + description=( + "If specified with 'logprobs', tokens are represented " + " as strings of the form 'token_id:{token_id}' so that tokens " + "that are not JSON-encodable can be identified.")) + + kv_transfer_params: Optional[dict[str, Any]] = Field( + default=None, + description="KVTransfer parameters used for disaggregated serving.") + + vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field( + default=None, + description=("Additional request parameters with string or " + "numeric values, used by custom extensions."), + ) + + # --8<-- [end:completion-extra-params] + + # Default sampling parameters for completion requests + _DEFAULT_SAMPLING_PARAMS: dict = { + "repetition_penalty": 1.0, + "temperature": 1.0, + "top_p": 1.0, + "top_k": 0, + "min_p": 0.0, + } + + def to_beam_search_params( + self, + max_tokens: int, + default_sampling_params: Optional[dict] = None, + ) -> BeamSearchParams: + + if default_sampling_params is None: + default_sampling_params = {} + n = self.n if self.n is not None else 1 + + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get("temperature", 1.0) + + return BeamSearchParams( + beam_width=n, + max_tokens=max_tokens, + ignore_eos=self.ignore_eos, + temperature=temperature, + length_penalty=self.length_penalty, + include_stop_str_in_output=self.include_stop_str_in_output, + ) + + def to_sampling_params( + self, + max_tokens: int, + logits_processor_pattern: Optional[str], + default_sampling_params: Optional[dict] = None, + ) -> SamplingParams: + + if default_sampling_params is None: + default_sampling_params = {} + + # Default parameters + if (repetition_penalty := self.repetition_penalty) is None: + repetition_penalty = default_sampling_params.get( + "repetition_penalty", + self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"], + ) + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + if (top_p := self.top_p) is None: + top_p = default_sampling_params.get( + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + if (top_k := self.top_k) is None: + top_k = default_sampling_params.get( + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) + if (min_p := self.min_p) is None: + min_p = default_sampling_params.get( + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) + + prompt_logprobs = self.prompt_logprobs + if prompt_logprobs is None and self.echo: + prompt_logprobs = self.logprobs + + echo_without_generation = self.echo and self.max_tokens == 0 + + guided_json_object = None + if (self.response_format is not None + and self.response_format.type == "json_object"): + guided_json_object = True + + guided_decoding = GuidedDecodingParams.from_optional( + json=self.guided_json, + regex=self.guided_regex, + choice=self.guided_choice, + grammar=self.guided_grammar, + json_object=guided_json_object, + backend=self.guided_decoding_backend, + whitespace_pattern=self.guided_whitespace_pattern, + ) + + extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {} + if self.kv_transfer_params: + # Pass in kv_transfer_params via extra_args + extra_args["kv_transfer_params"] = self.kv_transfer_params + return SamplingParams.from_optional( + n=self.n, + best_of=self.best_of, + presence_penalty=self.presence_penalty, + frequency_penalty=self.frequency_penalty, + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + seed=self.seed, + stop=self.stop, + stop_token_ids=self.stop_token_ids, + logprobs=self.logprobs, + ignore_eos=self.ignore_eos, + max_tokens=max_tokens if not echo_without_generation else 1, + min_tokens=self.min_tokens, + prompt_logprobs=prompt_logprobs, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + include_stop_str_in_output=self.include_stop_str_in_output, + logits_processors=get_logits_processors(self.logits_processors, + logits_processor_pattern), + truncate_prompt_tokens=self.truncate_prompt_tokens, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, + guided_decoding=guided_decoding, + logit_bias=self.logit_bias, + allowed_token_ids=self.allowed_token_ids, + extra_args=extra_args or None, + ) + + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + guide_count = sum([ + "guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice').") + return data + + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if (prompt_logprobs := data.get("prompt_logprobs")) is not None: + if data.get("stream") and prompt_logprobs > 0: + raise ValueError( + "`prompt_logprobs` are not available when `stream=True`.") + + if prompt_logprobs < 0: + raise ValueError("`prompt_logprobs` must be a positive value.") + + if (logprobs := data.get("logprobs")) is not None and logprobs < 0: + raise ValueError("`logprobs` must be a positive value.") + + return data + + @model_validator(mode="before") + @classmethod + def validate_stream_options(cls, data): + if data.get("stream_options") and not data.get("stream"): + raise ValueError( + "Stream options can only be defined when `stream=True`.") + + return data + + @model_validator(mode="before") + @classmethod + def validate_prompt_and_prompt_embeds(cls, data): + if data.get("prompt") is None and data.get("prompt_embeds") is None: + raise ValueError( + "At least one of `prompt` or `prompt_embeds` must be set.") + return data + + +class EmbeddingCompletionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/embeddings + model: Optional[str] = None + input: Union[list[int], list[list[int]], str, list[str]] + encoding_format: Literal["float", "base64"] = "float" + dimensions: Optional[int] = None + user: Optional[str] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None + + # --8<-- [start:embedding-pooling-params] + additional_data: Optional[Any] = None + # --8<-- [end:embedding-pooling-params] + + # --8<-- [start:embedding-extra-params] + add_special_tokens: bool = Field( + default=True, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " + "the prompt."), + ) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling."), + ) + + # --8<-- [end:embedding-extra-params] + + def to_pooling_params(self): + return PoolingParams(dimensions=self.dimensions, + additional_data=self.additional_data) + + +class EmbeddingChatRequest(OpenAIBaseModel): + model: Optional[str] = None + messages: list[ChatCompletionMessageParam] + + encoding_format: Literal["float", "base64"] = "float" + dimensions: Optional[int] = None + user: Optional[str] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None + + # --8<-- [start:chat-embedding-pooling-params] + additional_data: Optional[Any] = None + # --8<-- [end:chat-embedding-pooling-params] + + # --8<-- [start:chat-embedding-extra-params] + add_special_tokens: bool = Field( + default=False, + description=( + "If true, special tokens (e.g. BOS) will be added to the prompt " + "on top of what is added by the chat template. " + "For most models, the chat template takes care of adding the " + "special tokens so this should be set to false (as is the " + "default)."), + ) + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one."), + ) + chat_template_kwargs: Optional[dict[str, Any]] = Field( + default=None, + description=( + "Additional keyword args to pass to the template renderer. " + "Will be accessible by the chat template."), + ) + mm_processor_kwargs: Optional[dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the HF processor."), + ) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling."), + ) + # --8<-- [end:chat-embedding-extra-params] + + @model_validator(mode="before") + @classmethod + def check_generation_prompt(cls, data): + if data.get("continue_final_message") and data.get( + "add_generation_prompt"): + raise ValueError("Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True.") + return data + + def to_pooling_params(self): + return PoolingParams(dimensions=self.dimensions, + additional_data=self.additional_data) + + +EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] + +PoolingCompletionRequest = EmbeddingCompletionRequest +PoolingChatRequest = EmbeddingChatRequest +PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest] + + +class ScoreRequest(OpenAIBaseModel): + model: Optional[str] = None + text_1: Union[list[str], str] + text_2: Union[list[str], str] + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None + + # --8<-- [start:score-pooling-params] + additional_data: Optional[Any] = None + # --8<-- [end:score-pooling-params] + + # --8<-- [start:score-extra-params] + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling."), + ) + + # --8<-- [end:score-extra-params] + + def to_pooling_params(self, *, use_cross_encoder: bool = False): + return PoolingParams(use_cross_encoder=use_cross_encoder, + additional_data=self.additional_data) + + +class RerankRequest(OpenAIBaseModel): + model: Optional[str] = None + query: str + documents: list[str] + top_n: int = Field(default_factory=lambda: 0) + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None + + # --8<-- [start:rerank-pooling-params] + additional_data: Optional[Any] = None + # --8<-- [end:rerank-pooling-params] + + # --8<-- [start:rerank-extra-params] + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling."), + ) + + # --8<-- [end:rerank-extra-params] + + def to_pooling_params(self, *, use_cross_encoder: bool = False): + return PoolingParams(use_cross_encoder=use_cross_encoder, + additional_data=self.additional_data) + + +class RerankDocument(BaseModel): + text: str + + +class RerankResult(BaseModel): + index: int + document: RerankDocument + relevance_score: float + + +class RerankUsage(BaseModel): + total_tokens: int + + +class RerankResponse(OpenAIBaseModel): + id: str + model: str + usage: RerankUsage + results: list[RerankResult] + + +class CompletionLogProbs(OpenAIBaseModel): + text_offset: list[int] = Field(default_factory=list) + token_logprobs: list[Optional[float]] = Field(default_factory=list) + tokens: list[str] = Field(default_factory=list) + top_logprobs: list[Optional[dict[str, + float]]] = Field(default_factory=list) + + +class CompletionResponseChoice(OpenAIBaseModel): + index: int + text: str + logprobs: Optional[CompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( + default=None, + description=( + "The stop string or token id that caused the completion " + "to stop, None if the completion finished for some other reason " + "including encountering the EOS token"), + ) + prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None + + +class CompletionResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[CompletionResponseChoice] + usage: UsageInfo + kv_transfer_params: Optional[dict[str, Any]] = Field( + default=None, description="KVTransfer parameters.") + + +class CompletionResponseStreamChoice(OpenAIBaseModel): + index: int + text: str + logprobs: Optional[CompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( + default=None, + description=( + "The stop string or token id that caused the completion " + "to stop, None if the completion finished for some other reason " + "including encountering the EOS token"), + ) + + +class CompletionStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[CompletionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class EmbeddingResponseData(OpenAIBaseModel): + index: int + object: str = "embedding" + embedding: Union[list[float], str] + + +class EmbeddingResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"embd-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: list[EmbeddingResponseData] + usage: UsageInfo + + +class PoolingResponseData(OpenAIBaseModel): + index: int + object: str = "pooling" + data: Union[list[list[float]], list[float], str] + + +class PoolingResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"pool-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: list[PoolingResponseData] + usage: UsageInfo + + +class ScoreResponseData(OpenAIBaseModel): + index: int + object: str = "score" + score: float + + +class ScoreResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"embd-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: list[ScoreResponseData] + usage: UsageInfo + + +class ClassificationRequest(OpenAIBaseModel): + model: Optional[str] = None + input: Union[list[str], str] + truncate_prompt_tokens: Optional[int] = None + user: Optional[str] = None + + # --8<-- [start:classification-pooling-params] + additional_data: Optional[Any] = None + # --8<-- [end:classification-pooling-params] + + # --8<-- [start:classification-extra-params] + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling."), + ) + + # --8<-- [end:classification-extra-params] + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + +class ClassificationData(OpenAIBaseModel): + index: int + label: Optional[str] + probs: list[float] + num_classes: int + + +class ClassificationResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"classify-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: list[ClassificationData] + usage: UsageInfo + + +class FunctionCall(OpenAIBaseModel): + name: str + arguments: str + + +class ToolCall(OpenAIBaseModel): + id: str = Field(default_factory=random_tool_call_id) + type: Literal["function"] = "function" + function: FunctionCall + + +class DeltaFunctionCall(BaseModel): + name: Optional[str] = None + arguments: Optional[str] = None + + +# a tool call delta where everything is optional +class DeltaToolCall(OpenAIBaseModel): + id: Optional[str] = None + type: Optional[Literal["function"]] = None + index: int + function: Optional[DeltaFunctionCall] = None + + +class ExtractedToolCallInformation(BaseModel): + # indicate if tools were called + tools_called: bool + + # extracted tool calls + tool_calls: list[ToolCall] + + # content - per OpenAI spec, content AND tool calls can be returned rarely + # But some models will do this intentionally + content: Optional[str] = None + + +class ChatMessage(OpenAIBaseModel): + role: str + reasoning_content: Optional[str] = None + content: Optional[str] = None + tool_calls: list[ToolCall] = Field(default_factory=list) + + +class ChatCompletionLogProb(OpenAIBaseModel): + token: str + logprob: float = -9999.0 + bytes: Optional[list[int]] = None + + +class ChatCompletionLogProbsContent(ChatCompletionLogProb): + # Workaround: redefine fields name cache so that it's not + # shared with the super class. + field_names: ClassVar[Optional[set[str]]] = None + top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list) + + +class ChatCompletionLogProbs(OpenAIBaseModel): + content: Optional[list[ChatCompletionLogProbsContent]] = None + + +class ChatCompletionResponseChoice(OpenAIBaseModel): + index: int + message: ChatMessage + logprobs: Optional[ChatCompletionLogProbs] = None + # per OpenAI spec this is the default + finish_reason: Optional[str] = "stop" + # not part of the OpenAI spec but included in vLLM for legacy reasons + stop_reason: Optional[Union[int, str]] = None + + +class ChatCompletionResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: Literal["chat.completion"] = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[ChatCompletionResponseChoice] + usage: UsageInfo + prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None + kv_transfer_params: Optional[dict[str, Any]] = Field( + default=None, description="KVTransfer parameters.") + + +class DeltaMessage(OpenAIBaseModel): + role: Optional[str] = None + content: Optional[str] = None + reasoning_content: Optional[str] = None + tool_calls: list[DeltaToolCall] = Field(default_factory=list) + + +class ChatCompletionResponseStreamChoice(OpenAIBaseModel): + index: int + delta: DeltaMessage + logprobs: Optional[ChatCompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None + + +class ChatCompletionStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[ChatCompletionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class TranscriptionResponseStreamChoice(OpenAIBaseModel): + delta: DeltaMessage + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None + + +class TranscriptionStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}") + object: Literal["transcription.chunk"] = "transcription.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[TranscriptionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +BatchRequestInputBody = Union[ChatCompletionRequest, EmbeddingRequest, + ScoreRequest, RerankRequest] + + +class BatchRequestInput(OpenAIBaseModel): + """ + The per-line object of the batch input file. + + NOTE: Currently only the `/v1/chat/completions` endpoint is supported. + """ + + # A developer-provided per-request id that will be used to match outputs to + # inputs. Must be unique for each request in a batch. + custom_id: str + + # The HTTP method to be used for the request. Currently only POST is + # supported. + method: str + + # The OpenAI API relative URL to be used for the request. Currently + # /v1/chat/completions is supported. + url: str + + # The parameters of the request. + body: BatchRequestInputBody + + @field_validator('body', mode='plain') + @classmethod + def check_type_for_url(cls, value: Any, info: ValidationInfo): + # Use url to disambiguate models + url: str = info.data["url"] + if url == "/v1/chat/completions": + return ChatCompletionRequest.model_validate(value) + if url == "/v1/embeddings": + return TypeAdapter(EmbeddingRequest).validate_python(value) + if url.endswith("/score"): + return ScoreRequest.model_validate(value) + if url.endswith("/rerank"): + return RerankRequest.model_validate(value) + return TypeAdapter(BatchRequestInputBody).validate_python(value) + + +class BatchResponseData(OpenAIBaseModel): + # HTTP status code of the response. + status_code: int = 200 + + # An unique identifier for the API request. + request_id: str + + # The body of the response. + body: Optional[Union[ChatCompletionResponse, EmbeddingResponse, + ScoreResponse, RerankResponse]] = None + + +class BatchRequestOutput(OpenAIBaseModel): + """ + The per-line object of the batch output and error files + """ + + id: str + + # A developer-provided per-request id that will be used to match outputs to + # inputs. + custom_id: str + + response: Optional[BatchResponseData] + + # For requests that failed with a non-HTTP error, this will contain more + # information on the cause of the failure. + error: Optional[Any] + + +class TokenizeCompletionRequest(OpenAIBaseModel): + model: Optional[str] = None + prompt: str + + add_special_tokens: bool = Field( + default=True, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " + "the prompt."), + ) + return_token_strs: Optional[bool] = Field( + default=False, + description=("If true, also return the token strings " + "corresponding to the token ids."), + ) + + +class TokenizeChatRequest(OpenAIBaseModel): + model: Optional[str] = None + messages: list[ChatCompletionMessageParam] + + add_generation_prompt: bool = Field( + default=True, + description= + ("If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model."), + ) + return_token_strs: Optional[bool] = Field( + default=False, + description=("If true, also return the token strings " + "corresponding to the token ids."), + ) + continue_final_message: bool = Field( + default=False, + description= + ("If this is set, the chat will be formatted so that the final " + "message in the chat is open-ended, without any EOS tokens. The " + "model will continue this message rather than starting a new one. " + "This allows you to \"prefill\" part of the model's response for it. " + "Cannot be used at the same time as `add_generation_prompt`."), + ) + add_special_tokens: bool = Field( + default=False, + description=( + "If true, special tokens (e.g. BOS) will be added to the prompt " + "on top of what is added by the chat template. " + "For most models, the chat template takes care of adding the " + "special tokens so this should be set to false (as is the " + "default)."), + ) + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one."), + ) + chat_template_kwargs: Optional[dict[str, Any]] = Field( + default=None, + description=( + "Additional keyword args to pass to the template renderer. " + "Will be accessible by the chat template."), + ) + mm_processor_kwargs: Optional[dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the HF processor."), + ) + tools: Optional[list[ChatCompletionToolsParam]] = Field( + default=None, + description=("A list of tools the model may call."), + ) + + @model_validator(mode="before") + @classmethod + def check_generation_prompt(cls, data): + if data.get("continue_final_message") and data.get( + "add_generation_prompt"): + raise ValueError("Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True.") + return data + + +TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest] + + +class TokenizeResponse(OpenAIBaseModel): + count: int + max_model_len: int + tokens: list[int] + token_strs: Optional[list[str]] = None + + +class DetokenizeRequest(OpenAIBaseModel): + model: Optional[str] = None + tokens: list[int] + + +class DetokenizeResponse(OpenAIBaseModel): + prompt: str + + +class LoadLoRAAdapterRequest(BaseModel): + lora_name: str + lora_path: str + + +class UnloadLoRAAdapterRequest(BaseModel): + lora_name: str + lora_int_id: Optional[int] = Field(default=None) + + +## Protocols for Audio +AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", + "vtt"] + + +class TranscriptionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/audio/createTranscription + + file: UploadFile + """ + The audio file object (not file name) to transcribe, in one of these + formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. + """ + + model: Optional[str] = None + """ID of the model to use. + """ + + language: Optional[str] = None + """The language of the input audio. + + Supplying the input language in + [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format + will improve accuracy and latency. + """ + + prompt: str = Field(default="") + """An optional text to guide the model's style or continue a previous audio + segment. + + The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting) + should match the audio language. + """ + + response_format: AudioResponseFormat = Field(default="json") + """ + The format of the output, in one of these options: `json`, `text`, `srt`, + `verbose_json`, or `vtt`. + """ + + ## TODO (varun) : Support if set to 0, certain thresholds are met !! + + timestamp_granularities: list[Literal["word", "segment"]] = Field( + alias="timestamp_granularities[]", default=[]) + """The timestamp granularities to populate for this transcription. + + `response_format` must be set `verbose_json` to use timestamp granularities. + Either or both of these options are supported: `word`, or `segment`. Note: + There is no additional latency for segment timestamps, but generating word + timestamps incurs additional latency. + """ + + stream: Optional[bool] = False + """When set, it will enable output to be streamed in a similar fashion + as the Chat Completion endpoint. + """ + # --8<-- [start:transcription-extra-params] + # Flattened stream option to simplify form data. + stream_include_usage: Optional[bool] = False + stream_continuous_usage_stats: Optional[bool] = False + + vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field( + default=None, + description=("Additional request parameters with string or " + "numeric values, used by custom extensions."), + ) + # --8<-- [end:transcription-extra-params] + + # --8<-- [start:transcription-sampling-params] + temperature: float = Field(default=0.0) + """The sampling temperature, between 0 and 1. + + Higher values like 0.8 will make the output more random, while lower values + like 0.2 will make it more focused / deterministic. If set to 0, the model + will use [log probability](https://en.wikipedia.org/wiki/Log_probability) + to automatically increase the temperature until certain thresholds are hit. + """ + + top_p: Optional[float] = None + """Enables nucleus (top-p) sampling, where tokens are selected from the + smallest possible set whose cumulative probability exceeds `p`. + """ + + top_k: Optional[int] = None + """Limits sampling to the `k` most probable tokens at each step.""" + + min_p: Optional[float] = None + """Filters out tokens with a probability lower than `min_p`, ensuring a + minimum likelihood threshold during sampling. + """ + + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + """The seed to use for sampling.""" + + frequency_penalty: Optional[float] = 0.0 + """The frequency penalty to use for sampling.""" + + repetition_penalty: Optional[float] = None + """The repetition penalty to use for sampling.""" + + presence_penalty: Optional[float] = 0.0 + """The presence penalty to use for sampling.""" + # --8<-- [end:transcription-sampling-params] + + # Default sampling parameters for transcription requests. + _DEFAULT_SAMPLING_PARAMS: dict = { + "repetition_penalty": 1.0, + "temperature": 1.0, + "top_p": 1.0, + "top_k": 0, + "min_p": 0.0, + } + + def to_sampling_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None) -> SamplingParams: + + max_tokens = default_max_tokens + + if default_sampling_params is None: + default_sampling_params = {} + + # Default parameters + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + if (top_p := self.top_p) is None: + top_p = default_sampling_params.get( + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + if (top_k := self.top_k) is None: + top_k = default_sampling_params.get( + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) + if (min_p := self.min_p) is None: + min_p = default_sampling_params.get( + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) + + if (repetition_penalty := self.repetition_penalty) is None: + repetition_penalty = default_sampling_params.get( + "repetition_penalty", + self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"]) + + return SamplingParams.from_optional(temperature=temperature, + max_tokens=max_tokens, + seed=self.seed, + top_p=top_p, + top_k=top_k, + min_p=min_p, + frequency_penalty=self.frequency_penalty, + repetition_penalty=repetition_penalty, + presence_penalty=self.presence_penalty, + output_kind=RequestOutputKind.DELTA + if self.stream \ + else RequestOutputKind.FINAL_ONLY, + extra_args=self.vllm_xargs) + + @model_validator(mode="before") + @classmethod + def validate_transcription_request(cls, data): + if isinstance(data.get("file"), str): + raise HTTPException( + status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + detail="Expected 'file' to be a file-like object, not 'str'.", + ) + + stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"] + stream = data.get("stream", False) + if any(bool(data.get(so, False)) for so in stream_opts) and not stream: + raise ValueError( + "Stream options can only be defined when `stream=True`.") + + return data + + +# Transcription response objects +class TranscriptionResponse(OpenAIBaseModel): + text: str + """The transcribed text.""" + + +class TranscriptionWord(OpenAIBaseModel): + end: float + """End time of the word in seconds.""" + + start: float + """Start time of the word in seconds.""" + + word: str + """The text content of the word.""" + + +class TranscriptionSegment(OpenAIBaseModel): + id: int + """Unique identifier of the segment.""" + + avg_logprob: float + """Average logprob of the segment. + + If the value is lower than -1, consider the logprobs failed. + """ + + compression_ratio: float + """Compression ratio of the segment. + + If the value is greater than 2.4, consider the compression failed. + """ + + end: float + """End time of the segment in seconds.""" + + no_speech_prob: float + """Probability of no speech in the segment. + + If the value is higher than 1.0 and the `avg_logprob` is below -1, consider + this segment silent. + """ + + seek: int + """Seek offset of the segment.""" + + start: float + """Start time of the segment in seconds.""" + + temperature: float + """Temperature parameter used for generating the segment.""" + + text: str + """Text content of the segment.""" + + tokens: list[int] + """Array of token IDs for the text content.""" + + +class TranscriptionResponseVerbose(OpenAIBaseModel): + duration: str + """The duration of the input audio.""" + + language: str + """The language of the input audio.""" + + text: str + """The transcribed text.""" + + segments: Optional[list[TranscriptionSegment]] = None + """Segments of the transcribed text and their corresponding details.""" + + words: Optional[list[TranscriptionWord]] = None + """Extracted words and their corresponding timestamps.""" + + +class TranslationResponseStreamChoice(OpenAIBaseModel): + delta: DeltaMessage + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None + + +class TranslationStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}") + object: Literal["translation.chunk"] = "translation.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: list[TranslationResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class TranslationRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/audio/createTranslation + + file: UploadFile + """ + The audio file object (not file name) to translate, in one of these + formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. + """ + + model: Optional[str] = None + """ID of the model to use. + """ + + prompt: str = Field(default="") + """An optional text to guide the model's style or continue a previous audio + segment. + + The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting) + should match the audio language. + """ + + response_format: AudioResponseFormat = Field(default="json") + """ + The format of the output, in one of these options: `json`, `text`, `srt`, + `verbose_json`, or `vtt`. + """ + + # TODO support additional sampling parameters + # --8<-- [start:translation-sampling-params] + temperature: float = Field(default=0.0) + """The sampling temperature, between 0 and 1. + + Higher values like 0.8 will make the output more random, while lower values + like 0.2 will make it more focused / deterministic. If set to 0, the model + will use [log probability](https://en.wikipedia.org/wiki/Log_probability) + to automatically increase the temperature until certain thresholds are hit. + """ + # --8<-- [end:translation-sampling-params] + + # --8<-- [start:translation-extra-params] + language: Optional[str] = None + """The language of the input audio we translate from. + + Supplying the input language in + [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format + will improve accuracy. + """ + + stream: Optional[bool] = False + """Custom field not present in the original OpenAI definition. When set, + it will enable output to be streamed in a similar fashion as the Chat + Completion endpoint. + """ + # Flattened stream option to simplify form data. + stream_include_usage: Optional[bool] = False + stream_continuous_usage_stats: Optional[bool] = False + # --8<-- [end:translation-extra-params] + + # Default sampling parameters for translation requests. + _DEFAULT_SAMPLING_PARAMS: dict = { + "temperature": 0, + } + + def to_sampling_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None) -> SamplingParams: + + max_tokens = default_max_tokens + + if default_sampling_params is None: + default_sampling_params = {} + # Default parameters + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + + return SamplingParams.from_optional(temperature=temperature, + max_tokens=max_tokens, + output_kind=RequestOutputKind.DELTA + if self.stream \ + else RequestOutputKind.FINAL_ONLY) + + @model_validator(mode="before") + @classmethod + def validate_stream_options(cls, data): + stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"] + stream = data.get("stream", False) + if any(bool(data.get(so, False)) for so in stream_opts) and not stream: + raise ValueError( + "Stream options can only be defined when `stream=True`.") + + return data + + +# Translation response objects +class TranslationResponse(OpenAIBaseModel): + text: str + """The translated text.""" + + +class TranslationWord(OpenAIBaseModel): + end: float + """End time of the word in seconds.""" + + start: float + """Start time of the word in seconds.""" + + word: str + """The text content of the word.""" + + +class TranslationSegment(OpenAIBaseModel): + id: int + """Unique identifier of the segment.""" + + avg_logprob: float + """Average logprob of the segment. + + If the value is lower than -1, consider the logprobs failed. + """ + + compression_ratio: float + """Compression ratio of the segment. + + If the value is greater than 2.4, consider the compression failed. + """ + + end: float + """End time of the segment in seconds.""" + + no_speech_prob: float + """Probability of no speech in the segment. + + If the value is higher than 1.0 and the `avg_logprob` is below -1, consider + this segment silent. + """ + + seek: int + """Seek offset of the segment.""" + + start: float + """Start time of the segment in seconds.""" + + temperature: float + """Temperature parameter used for generating the segment.""" + + text: str + """Text content of the segment.""" + + tokens: list[int] + """Array of token IDs for the text content.""" + + +class TranslationResponseVerbose(OpenAIBaseModel): + duration: str + """The duration of the input audio.""" + + language: str + """The language of the input audio.""" + + text: str + """The translated text.""" + + segments: Optional[list[TranslationSegment]] = None + """Segments of the translated text and their corresponding details.""" + + words: Optional[list[TranslationWord]] = None + """Extracted words and their corresponding timestamps.""" diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py new file mode 100644 index 0000000..e112e2f --- /dev/null +++ b/vllm/entrypoints/openai/run_batch.py @@ -0,0 +1,473 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import tempfile +from collections.abc import Awaitable +from http import HTTPStatus +from io import StringIO +from typing import Callable, Optional + +import aiohttp +import torch +from prometheus_client import start_http_server +from tqdm import tqdm + +from vllm.engine.arg_utils import AsyncEngineArgs, optional_type +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.logger import RequestLogger +# yapf: disable +from vllm.entrypoints.openai.protocol import (BatchRequestInput, + BatchRequestOutput, + BatchResponseData, + ChatCompletionResponse, + EmbeddingResponse, ErrorResponse, + RerankResponse, ScoreResponse) +# yapf: enable +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_models import (BaseModelPath, + OpenAIServingModels) +from vllm.entrypoints.openai.serving_score import ServingScores +from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser, random_uuid +from vllm.version import __version__ as VLLM_VERSION + +logger = init_logger(__name__) + + +def make_arg_parser(parser: FlexibleArgumentParser): + parser.add_argument( + "-i", + "--input-file", + required=True, + type=str, + help= + "The path or url to a single input file. Currently supports local file " + "paths, or the http protocol (http or https). If a URL is specified, " + "the file should be available via HTTP GET.") + parser.add_argument( + "-o", + "--output-file", + required=True, + type=str, + help="The path or url to a single output file. Currently supports " + "local file paths, or web (http or https) urls. If a URL is specified," + " the file should be available via HTTP PUT.") + parser.add_argument( + "--output-tmp-dir", + type=str, + default=None, + help="The directory to store the output file before uploading it " + "to the output URL.", + ) + parser.add_argument("--response-role", + type=optional_type(str), + default="assistant", + help="The role name to return if " + "`request.add_generation_prompt=True`.") + + parser = AsyncEngineArgs.add_cli_args(parser) + + parser.add_argument('--max-log-len', + type=int, + default=None, + help='Max number of prompt characters or prompt ' + 'ID numbers being printed in log.' + '\n\nDefault: Unlimited') + + parser.add_argument("--enable-metrics", + action="store_true", + help="Enable Prometheus metrics") + parser.add_argument( + "--url", + type=str, + default="0.0.0.0", + help="URL to the Prometheus metrics server " + "(only needed if enable-metrics is set).", + ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port number for the Prometheus metrics server " + "(only needed if enable-metrics is set).", + ) + parser.add_argument( + "--enable-prompt-tokens-details", + action='store_true', + default=False, + help="If set to True, enable prompt_tokens_details in usage.") + + return parser + + +def parse_args(): + parser = FlexibleArgumentParser( + description="vLLM OpenAI-Compatible batch runner.") + return make_arg_parser(parser).parse_args() + + +# explicitly use pure text format, with a newline at the end +# this makes it impossible to see the animation in the progress bar +# but will avoid messing up with ray or multiprocessing, which wraps +# each line of output with some prefix. +_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501 + + +class BatchProgressTracker: + + def __init__(self): + self._total = 0 + self._pbar: Optional[tqdm] = None + + def submitted(self): + self._total += 1 + + def completed(self): + if self._pbar: + self._pbar.update() + + def pbar(self) -> tqdm: + enable_tqdm = not torch.distributed.is_initialized( + ) or torch.distributed.get_rank() == 0 + self._pbar = tqdm(total=self._total, + unit="req", + desc="Running batch", + mininterval=5, + disable=not enable_tqdm, + bar_format=_BAR_FORMAT) + return self._pbar + + +async def read_file(path_or_url: str) -> str: + if path_or_url.startswith("http://") or path_or_url.startswith("https://"): + async with aiohttp.ClientSession() as session, \ + session.get(path_or_url) as resp: + return await resp.text() + else: + with open(path_or_url, encoding="utf-8") as f: + return f.read() + + +async def write_local_file(output_path: str, + batch_outputs: list[BatchRequestOutput]) -> None: + """ + Write the responses to a local file. + output_path: The path to write the responses to. + batch_outputs: The list of batch outputs to write. + """ + # We should make this async, but as long as run_batch runs as a + # standalone program, blocking the event loop won't effect performance. + with open(output_path, "w", encoding="utf-8") as f: + for o in batch_outputs: + print(o.model_dump_json(), file=f) + + +async def upload_data(output_url: str, data_or_file: str, + from_file: bool) -> None: + """ + Upload a local file to a URL. + output_url: The URL to upload the file to. + data_or_file: Either the data to upload or the path to the file to upload. + from_file: If True, data_or_file is the path to the file to upload. + """ + # Timeout is a common issue when uploading large files. + # We retry max_retries times before giving up. + max_retries = 5 + # Number of seconds to wait before retrying. + delay = 5 + + for attempt in range(1, max_retries + 1): + try: + # We increase the timeout to 1000 seconds to allow + # for large files (default is 300). + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout( + total=1000)) as session: + if from_file: + with open(data_or_file, "rb") as file: + async with session.put(output_url, + data=file) as response: + if response.status != 200: + raise Exception(f"Failed to upload file.\n" + f"Status: {response.status}\n" + f"Response: {response.text()}") + else: + async with session.put(output_url, + data=data_or_file) as response: + if response.status != 200: + raise Exception(f"Failed to upload data.\n" + f"Status: {response.status}\n" + f"Response: {response.text()}") + + except Exception as e: + if attempt < max_retries: + logger.error( + "Failed to upload data (attempt %d). Error message: %s.\nRetrying in %d seconds...", # noqa: E501 + attempt, + e, + delay, + ) + await asyncio.sleep(delay) + else: + raise Exception( + f"Failed to upload data (attempt {attempt}). Error message: {str(e)}." # noqa: E501 + ) from e + + +async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput], + output_tmp_dir: str) -> None: + """ + Write batch_outputs to a file or upload to a URL. + path_or_url: The path or URL to write batch_outputs to. + batch_outputs: The list of batch outputs to write. + output_tmp_dir: The directory to store the output file before uploading it + to the output URL. + """ + if path_or_url.startswith("http://") or path_or_url.startswith("https://"): + if output_tmp_dir is None: + logger.info("Writing outputs to memory buffer") + output_buffer = StringIO() + for o in batch_outputs: + print(o.model_dump_json(), file=output_buffer) + output_buffer.seek(0) + logger.info("Uploading outputs to %s", path_or_url) + await upload_data( + path_or_url, + output_buffer.read().strip().encode("utf-8"), + from_file=False, + ) + else: + # Write responses to a temporary file and then upload it to the URL. + with tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + dir=output_tmp_dir, + prefix="tmp_batch_output_", + suffix=".jsonl", + ) as f: + logger.info("Writing outputs to temporary local file %s", + f.name) + await write_local_file(f.name, batch_outputs) + logger.info("Uploading outputs to %s", path_or_url) + await upload_data(path_or_url, f.name, from_file=True) + else: + logger.info("Writing outputs to local file %s", path_or_url) + await write_local_file(path_or_url, batch_outputs) + + +def make_error_request_output(request: BatchRequestInput, + error_msg: str) -> BatchRequestOutput: + batch_output = BatchRequestOutput( + id=f"vllm-{random_uuid()}", + custom_id=request.custom_id, + response=BatchResponseData( + status_code=HTTPStatus.BAD_REQUEST, + request_id=f"vllm-batch-{random_uuid()}", + ), + error=error_msg, + ) + return batch_output + + +async def make_async_error_request_output( + request: BatchRequestInput, error_msg: str) -> BatchRequestOutput: + return make_error_request_output(request, error_msg) + + +async def run_request(serving_engine_func: Callable, + request: BatchRequestInput, + tracker: BatchProgressTracker) -> BatchRequestOutput: + response = await serving_engine_func(request.body) + + if isinstance( + response, + (ChatCompletionResponse, EmbeddingResponse, ScoreResponse, + RerankResponse), + ): + batch_output = BatchRequestOutput( + id=f"vllm-{random_uuid()}", + custom_id=request.custom_id, + response=BatchResponseData( + body=response, request_id=f"vllm-batch-{random_uuid()}"), + error=None, + ) + elif isinstance(response, ErrorResponse): + batch_output = BatchRequestOutput( + id=f"vllm-{random_uuid()}", + custom_id=request.custom_id, + response=BatchResponseData( + status_code=response.code, + request_id=f"vllm-batch-{random_uuid()}"), + error=response, + ) + else: + batch_output = make_error_request_output( + request, error_msg="Request must not be sent in stream mode") + + tracker.completed() + return batch_output + + +async def main(args): + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER) + + model_config = await engine.get_model_config() + base_model_paths = [ + BaseModelPath(name=name, model_path=args.model) + for name in served_model_names + ] + + if args.disable_log_requests: + request_logger = None + else: + request_logger = RequestLogger(max_log_len=args.max_log_len) + + # Create the openai serving objects. + openai_serving_models = OpenAIServingModels( + engine_client=engine, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=None, + prompt_adapters=None, + ) + openai_serving_chat = OpenAIServingChat( + engine, + model_config, + openai_serving_models, + args.response_role, + request_logger=request_logger, + chat_template=None, + chat_template_content_format="auto", + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + ) if model_config.runner_type == "generate" else None + openai_serving_embedding = OpenAIServingEmbedding( + engine, + model_config, + openai_serving_models, + request_logger=request_logger, + chat_template=None, + chat_template_content_format="auto", + ) if model_config.task == "embed" else None + + enable_serving_reranking = (model_config.task == "classify" and getattr( + model_config.hf_config, "num_labels", 0) == 1) + + openai_serving_scores = (ServingScores( + engine, + model_config, + openai_serving_models, + request_logger=request_logger, + ) if (model_config.task == "embed" or enable_serving_reranking) else None) + + tracker = BatchProgressTracker() + logger.info("Reading batch from %s...", args.input_file) + + # Submit all requests in the file to the engine "concurrently". + response_futures: list[Awaitable[BatchRequestOutput]] = [] + for request_json in (await read_file(args.input_file)).strip().split("\n"): + # Skip empty lines. + request_json = request_json.strip() + if not request_json: + continue + + request = BatchRequestInput.model_validate_json(request_json) + + # Determine the type of request and run it. + if request.url == "/v1/chat/completions": + chat_handler_fn = openai_serving_chat.create_chat_completion if \ + openai_serving_chat is not None else None + if chat_handler_fn is None: + response_futures.append( + make_async_error_request_output( + request, + error_msg= + "The model does not support Chat Completions API", + )) + continue + + response_futures.append( + run_request(chat_handler_fn, request, tracker)) + tracker.submitted() + elif request.url == "/v1/embeddings": + embed_handler_fn = openai_serving_embedding.create_embedding if \ + openai_serving_embedding is not None else None + if embed_handler_fn is None: + response_futures.append( + make_async_error_request_output( + request, + error_msg="The model does not support Embeddings API", + )) + continue + + response_futures.append( + run_request(embed_handler_fn, request, tracker)) + tracker.submitted() + elif request.url.endswith("/score"): + score_handler_fn = openai_serving_scores.create_score if \ + openai_serving_scores is not None else None + if score_handler_fn is None: + response_futures.append( + make_async_error_request_output( + request, + error_msg="The model does not support Scores API", + )) + continue + + response_futures.append( + run_request(score_handler_fn, request, tracker)) + tracker.submitted() + elif request.url.endswith("/rerank"): + rerank_handler_fn = openai_serving_scores.do_rerank if \ + openai_serving_scores is not None else None + if rerank_handler_fn is None: + response_futures.append( + make_async_error_request_output( + request, + error_msg="The model does not support Rerank API", + )) + continue + + response_futures.append( + run_request(rerank_handler_fn, request, tracker)) + tracker.submitted() + else: + response_futures.append( + make_async_error_request_output( + request, + error_msg=f"URL {request.url} was used. " + "Supported endpoints: /v1/chat/completions, /v1/embeddings," + " /score, /rerank ." + "See vllm/entrypoints/openai/api_server.py for supported " + "score/rerank versions.", + )) + + with tracker.pbar(): + responses = await asyncio.gather(*response_futures) + + await write_file(args.output_file, responses, args.output_tmp_dir) + + +if __name__ == "__main__": + args = parse_args() + + logger.info("vLLM batch processing API version %s", VLLM_VERSION) + logger.info("args: %s", args) + + # Start the Prometheus metrics server. LLMEngine uses the Prometheus client + # to publish metrics at the /metrics endpoint. + if args.enable_metrics: + logger.info("Prometheus metrics enabled") + start_http_server(port=args.port, addr=args.url) + else: + logger.info("Prometheus metrics disabled") + + asyncio.run(main(args)) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py new file mode 100644 index 0000000..a802fbc --- /dev/null +++ b/vllm/entrypoints/openai/serving_chat.py @@ -0,0 +1,1258 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import json +import time +from collections.abc import AsyncGenerator, AsyncIterator +from collections.abc import Sequence as GenericSequence +from typing import Callable, Final, Optional, Union + +import jinja2 +import partial_json_parser +import regex as re +from fastapi import Request +from pydantic import TypeAdapter + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, + ConversationMessage, + random_tool_call_id) +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import ( + ChatCompletionLogProb, ChatCompletionLogProbs, + ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, + ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, + DeltaToolCall, ErrorResponse, FunctionCall, FunctionDefinition, + PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo) +from vllm.entrypoints.openai.serving_engine import (OpenAIServing, + clamp_prompt_logprobs) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager +from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( + MistralToolCall) +from vllm.entrypoints.utils import get_max_tokens +from vllm.logger import init_logger +from vllm.outputs import CompletionOutput, RequestOutput +from vllm.reasoning import ReasoningParser, ReasoningParserManager +from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sequence import Logprob +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls, + truncate_tool_call_ids, + validate_request_params) + +logger = init_logger(__name__) + + +class OpenAIServingChat(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + response_role: str, + *, + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + return_tokens_as_token_ids: bool = False, + reasoning_parser: str = "", + enable_auto_tools: bool = False, + expand_tools_even_if_tool_choice_none: bool = False, + tool_parser: Optional[str] = None, + enable_prompt_tokens_details: bool = False, + enable_force_include_usage: bool = False, + ) -> None: + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + enable_force_include_usage=enable_force_include_usage) + + self.response_role = response_role + self.chat_template = chat_template + self.chat_template_content_format: Final = chat_template_content_format + + # set up tool use + self.enable_auto_tools: bool = enable_auto_tools + if self.enable_auto_tools: + logger.info( + "\"auto\" tool choice has been enabled please note that while" + " the parallel_tool_calls client option is preset for " + "compatibility reasons, it will be ignored.") + + self.reasoning_parser: Optional[Callable[[AnyTokenizer], + ReasoningParser]] = None + if reasoning_parser: + try: + self.reasoning_parser = ( + ReasoningParserManager.get_reasoning_parser( + reasoning_parser)) + assert self.reasoning_parser is not None + except Exception as e: + raise TypeError( + f"{reasoning_parser=} has not been registered") from e + self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None + if self.enable_auto_tools: + try: + if (tool_parser == "pythonic" and + model_config.model.startswith("meta-llama/Llama-3.2")): + logger.warning( + "Llama3.2 models may struggle to emit valid pythonic" + " tool calls") + self.tool_parser = ToolParserManager.get_tool_parser( + tool_parser) + except Exception as e: + raise TypeError("Error: --enable-auto-tool-choice requires " + f"tool_parser:'{tool_parser}' which has not " + "been registered") from e + self.expand_tools_even_if_tool_choice_none = ( + expand_tools_even_if_tool_choice_none) + + self.enable_prompt_tokens_details = enable_prompt_tokens_details + self.enable_force_include_usage = enable_force_include_usage + self.default_sampling_params = ( + self.model_config.get_diff_sampling_param()) + if self.default_sampling_params: + source = self.model_config.generation_config + source = "model" if source == "auto" else source + logger.info("Using default chat sampling params from %s: %s", + source, self.default_sampling_params) + + async def create_chat_completion( + self, + request: ChatCompletionRequest, + raw_request: Optional[Request] = None, + ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse, + ErrorResponse]: + """ + Chat Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/chat/create + for the API specification. This API mimics the OpenAI + Chat Completion API. + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + logger.error("Error with model %s", error_check_ret) + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + model_name = self._get_model_name(request.model, lora_request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + tool_parser = self.tool_parser + + if isinstance(tokenizer, MistralTokenizer): + # because of issues with pydantic we need to potentially + # re-serialize the tool_calls field of the request + # for more info: see comment in `maybe_serialize_tool_calls` + maybe_serialize_tool_calls(request) + truncate_tool_call_ids(request) + validate_request_params(request) + + if (request.tool_choice == "auto" and + not (self.enable_auto_tools and tool_parser is not None) + and not isinstance(tokenizer, MistralTokenizer)): + # for hf tokenizers, "auto" tools requires + # --enable-auto-tool-choice and --tool-call-parser + return self.create_error_response( + "\"auto\" tool choice requires " + "--enable-auto-tool-choice and --tool-call-parser to be set" + ) + + if request.tools is None: + tool_dicts = None + elif (request.tool_choice == "none" + and not self.expand_tools_even_if_tool_choice_none): + if len(request.tools) > 0: + logger.warning_once( + "Tools are specified but tool_choice is set to 'none' " + "and --expand-tools-even-if-tool-choice-none is not " + "enabled. Tool definitions will be excluded from the " + "prompt. This behavior will change in vLLM v0.10 where " + "tool definitions will be included by default even " + "with tool_choice='none'. To adopt the new behavior " + "now, use --expand-tools-even-if-tool-choice-none. " + "To suppress this warning, either remove tools from " + "the request or set tool_choice to a different value.") + tool_dicts = None + else: + tool_dicts = [tool.model_dump() for tool in request.tools] + + ( + conversation, + request_prompts, + engine_prompts, + ) = await self._preprocess_chat( + request, + tokenizer, + request.messages, + chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self.chat_template_content_format, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, + tool_dicts=tool_dicts, + documents=request.documents, + chat_template_kwargs=request.chat_template_kwargs, + tool_parser=tool_parser, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + except (ValueError, TypeError, RuntimeError, + jinja2.TemplateError) as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(f"{e} {e.__cause__}") + + request_id = "chatcmpl-" \ + f"{self._base_request_id(raw_request, request.request_id)}" + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + # Schedule the request and get the result generator. + generators: list[AsyncGenerator[RequestOutput, None]] = [] + try: + for i, engine_prompt in enumerate(engine_prompts): + sampling_params: Union[SamplingParams, BeamSearchParams] + + if self.default_sampling_params is None: + self.default_sampling_params = {} + + max_tokens = get_max_tokens( + max_model_len=self.max_model_len, + request=request, + input_length=len(engine_prompt["prompt_token_ids"]), + default_sampling_params=self.default_sampling_params) + + if request.use_beam_search: + sampling_params = request.to_beam_search_params( + max_tokens, self.default_sampling_params) + else: + sampling_params = request.to_sampling_params( + max_tokens, self.model_config.logits_processor_pattern, + self.default_sampling_params) + + self._log_inputs(request_id, + request_prompts[i], + params=sampling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + if isinstance(sampling_params, BeamSearchParams): + generator = self.engine_client.beam_search( + prompt=engine_prompt, + request_id=request_id, + params=sampling_params, + lora_request=lora_request, + ) + else: + generator = self.engine_client.generate( + engine_prompt, + sampling_params, + request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=request.priority, + ) + + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + assert len(generators) == 1 + result_generator, = generators + + # Streaming response + if request.stream: + return self.chat_completion_stream_generator( + request, + result_generator, + request_id, + model_name, + conversation, + tokenizer, + request_metadata, + enable_force_include_usage=self.enable_force_include_usage) + + try: + return await self.chat_completion_full_generator( + request, result_generator, request_id, model_name, + conversation, tokenizer, request_metadata) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + def get_chat_request_role(self, request: ChatCompletionRequest) -> str: + if request.add_generation_prompt: + return self.response_role + return request.messages[-1]["role"] + + @staticmethod + def _bracket_level(s: str, opening='{', closing='}') -> int: + """ + Calculate the current level of nested brackets in a given string. + """ + level = 0 + for char in s: + if char == opening: + level += 1 + elif char == closing: + level -= 1 + return level + + @staticmethod + def _filter_delta_text(delta_text: str, + previous_text: str) -> tuple[str, bool]: + # remove last '},' of the tool definition stemming from the + # "name"/"parameters" outer object or closing ']' of the tool list + # count occurrences of opening and closing curly braces and + # once level 0 is reached stop outputting text + # if 0 is reached while parsing the delta_text we know the current + # tool will finish in this current iteration + bracket_level = OpenAIServingChat._bracket_level(previous_text) + updated_delta, passed_zero = "", False + for c in delta_text: + if c == '{': + bracket_level += 1 + passed_zero = bracket_level == 0 + elif c == '}': + bracket_level -= 1 + passed_zero = bracket_level == 0 + + if bracket_level != 0: + updated_delta += c + else: + # if a comma is reached at level 0 we can stop + if c == ',': + break + return updated_delta, passed_zero + + def extract_tool_call_required_streaming( + self, + previous_text: str, + current_text: Optional[str], + delta_text: str, + function_name_returned: bool, + ) -> tuple[Optional[DeltaMessage], bool]: + if current_text is None or current_text == "": + # if the current text is empty, we cannot parse it + return None, function_name_returned + try: + obj = partial_json_parser.loads(current_text) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + obj = None + + # check if the current text is a valid array + # containing a partial tool calling object + # if not repeat + if obj is None or not isinstance(obj, list) or not len(obj) > 0: + function_name_returned = False + delta_message = None + else: + _, finishes_previous_tool = OpenAIServingChat._filter_delta_text( + delta_text, previous_text) + # take the last tool call from the generated list + current_tool_call = obj[-1] + + # once parameters have been generated the name is complete as well + if not finishes_previous_tool and ("name" not in current_tool_call + or "parameters" + not in current_tool_call): + function_name_returned = False + delta_message = None + else: + if not function_name_returned: + # get partly generated arguments from the latest tool call + param_match = re.search(r'.*"parameters":\s*(.*)', + current_text) + arguments = param_match.group(1) if param_match else "" + arguments, _ = OpenAIServingChat._filter_delta_text( + arguments, previous_text) + + # if this iteration finishes a previous tool call but a + # new incomplete tool is already generated, take the + # previous from the list + if (finishes_previous_tool + and "parameters" not in current_tool_call): + current_tool_call = obj[-2] + + function_name_returned = True + delta_message = DeltaMessage(tool_calls=[ + DeltaToolCall(id=random_tool_call_id(), + function=DeltaFunctionCall( + name=current_tool_call["name"], + arguments=arguments), + index=len(obj) - 1, + type="function") + ]) + + else: + delta_text, _ = OpenAIServingChat._filter_delta_text( + delta_text, previous_text) + + if delta_text != "": + delta_message = DeltaMessage(tool_calls=[ + DeltaToolCall( + function=DeltaFunctionCall( + # OpenAI API returns None + # instead of name every time + name=None, + arguments=delta_text), + index=len(obj) - 1) + ]) + else: + delta_message = None + + return delta_message, function_name_returned + + async def chat_completion_stream_generator( + self, + request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], + request_id: str, + model_name: str, + conversation: list[ConversationMessage], + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + enable_force_include_usage: bool, + ) -> AsyncGenerator[str, None]: + created_time = int(time.time()) + chunk_object_type: Final = "chat.completion.chunk" + first_iteration = True + + # Send response for each token for each request.n (index) + num_choices = 1 if request.n is None else request.n + previous_num_tokens = [0] * num_choices + finish_reason_sent = [False] * num_choices + num_prompt_tokens = 0 + num_cached_tokens = None + + if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): + tool_choice_function_name = request.tool_choice.function.name + else: + tool_choice_function_name = None + + # Determine whether tools are in use with "auto" tool choice + tool_choice_auto = ( + not tool_choice_function_name + and self._should_stream_with_auto_tool_parsing(request)) + + all_previous_token_ids: Optional[list[list[int]]] + function_name_returned = [False] * num_choices + + # Only one of these will be used, thus previous_texts and + # all_previous_token_ids will not be used twice in the same iteration. + if tool_choice_auto or self.reasoning_parser: + # These are only required in "auto" tool choice case + previous_texts = [""] * num_choices + all_previous_token_ids = [[]] * num_choices + # For reasoning parser and tool call all enabled + added_content_delta_arr = [False] * num_choices + reasoning_end_arr = [False] * num_choices + elif request.tool_choice == "required": + previous_texts = [""] * num_choices + all_previous_token_ids = None + else: + previous_texts, all_previous_token_ids = None, None + + try: + if self.reasoning_parser: + reasoning_parser = self.reasoning_parser(tokenizer) + except RuntimeError as e: + logger.exception("Error in reasoning parser creation.") + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + yield "data: [DONE]\n\n" + return + # Prepare the tool parser if it's needed + try: + if tool_choice_auto and self.tool_parser: + tool_parsers: list[Optional[ToolParser]] = [ + self.tool_parser(tokenizer) + ] * num_choices + else: + tool_parsers = [None] * num_choices + except Exception as e: + logger.exception("Error in tool parser creation.") + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + yield "data: [DONE]\n\n" + return + + stream_options = request.stream_options + if stream_options: + include_usage = stream_options.include_usage \ + or enable_force_include_usage + include_continuous_usage = include_usage and \ + stream_options.continuous_usage_stats + else: + include_usage, include_continuous_usage = False, False + + try: + async for res in result_generator: + if res.prompt_token_ids is not None: + num_prompt_tokens = len(res.prompt_token_ids) + if res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(res.encoder_prompt_token_ids) + + # We need to do it here, because if there are exceptions in + # the result_generator, it needs to be sent as the FIRST + # response (by the try...catch). + if first_iteration: + num_cached_tokens = res.num_cached_tokens + # Send first response for each request.n (index) with + # the role + role = self.get_chat_request_role(request) + + # NOTE num_choices defaults to 1 so this usually executes + # once per request + for i in range(num_choices): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage( + role=role, + content="", + ), + logprobs=None, + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + + # if continuous usage stats are requested, add it + if include_continuous_usage: + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens) + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # Send response to echo the input portion of the + # last message + if request.echo: + last_msg_content: Union[str, list[dict[str, str]]] = "" + if conversation and "content" in conversation[ + -1] and conversation[-1].get("role") == role: + last_msg_content = conversation[-1]["content"] or "" + + if last_msg_content: + for i in range(num_choices): + choice_data = ( + ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage( + content=last_msg_content), + logprobs=None, + finish_reason=None)) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + if include_continuous_usage: + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens) + + data = chunk.model_dump_json( + exclude_unset=True) + yield f"data: {data}\n\n" + first_iteration = False + + for output in res.outputs: + i = output.index + tool_parser = tool_parsers[i] + + if finish_reason_sent[i]: + continue + + if request.logprobs and request.top_logprobs is not None: + assert output.logprobs is not None, ( + "Did not output logprobs") + logprobs = self._create_chat_logprobs( + token_ids=output.token_ids, + top_logprobs=output.logprobs, + tokenizer=tokenizer, + num_output_top_logprobs=request.top_logprobs, + return_as_token_id=request. + return_tokens_as_token_ids, + ) + else: + logprobs = None + + delta_text = output.text + + if not delta_text and not output.token_ids and \ + not previous_num_tokens[i]: + # Chunked prefill case, don't return empty chunks + continue + + delta_message: Optional[DeltaMessage] + + # just update previous_texts and previous_token_ids + if tool_choice_auto or self.reasoning_parser: + assert previous_texts is not None + assert all_previous_token_ids is not None + previous_text = previous_texts[i] + previous_token_ids = all_previous_token_ids[i] + current_text = previous_text + delta_text + current_token_ids = previous_token_ids + list( + output.token_ids) + + # handle streaming deltas for tools with named tool_choice + if tool_choice_function_name: + if (self.reasoning_parser + and not reasoning_parser.is_reasoning_end( + previous_token_ids)): + assert reasoning_parser is not None + delta_message = ( + reasoning_parser. + extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + )) + # When encountering think end id in delta_token_ids, + # process the `content`. Only keep 'content', + # remove 'reasoning_content' + if reasoning_parser.is_reasoning_end( + list(output.token_ids)): + if delta_message and delta_message.content: + # This need to be added to next `delta_text` + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" + else: + # Just to add remaining `content` + if self.reasoning_parser: + delta_text = previous_text + delta_text + current_text = "" + + if function_name_returned[i]: + delta_tool_call = DeltaToolCall( + function=DeltaFunctionCall( + arguments=delta_text), + index=i) + else: + delta_tool_call = DeltaToolCall( + id=random_tool_call_id(), + type="function", + function=DeltaFunctionCall( + name=tool_choice_function_name, + arguments=delta_text), + index=i) + function_name_returned[i] = True + + delta_message = DeltaMessage(tool_calls=[ + delta_tool_call, + ]) + + elif request.tool_choice == "required": + assert previous_texts is not None + previous_text = previous_texts[i] + current_text = previous_text + delta_text + fn_name_returned = function_name_returned[i] + + if self.reasoning_parser: + _, content = \ + reasoning_parser.extract_reasoning_content( + current_text, + request + ) + else: + content = current_text + delta_message, function_name_returned[i] = ( + self.extract_tool_call_required_streaming( + previous_text=previous_text, + current_text=content, + delta_text=delta_text, + function_name_returned=fn_name_returned)) + + # update the previous values for the next iteration + previous_texts[i] = current_text + + # handle streaming deltas for tools with "auto" tool choice + # and reasoning parser + elif tool_choice_auto and self.reasoning_parser: + assert tool_parser is not None + assert reasoning_parser is not None + assert added_content_delta_arr is not None + assert reasoning_end_arr is not None + if not reasoning_end_arr[i]: + delta_message = ( + reasoning_parser. + extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + )) + # When encountering think end id in prompt_token_ids + # i.e {"enable_thinking": False}, + # set reasoning status to end. + # Remove the text and token ids related + # to 'reasoning_content'. + if res.prompt_token_ids and \ + reasoning_parser.is_reasoning_end( + list(res.prompt_token_ids)): + reasoning_end_arr[i] = True + current_token_ids = list(output.token_ids) + if delta_message and delta_message.content: + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" + # When encountering think end id in delta_token_ids, + # set reasoning status to end. + # Remove the text and token ids related + # to 'reasoning_content'. + if reasoning_parser.is_reasoning_end( + list(output.token_ids)): + reasoning_end_arr[i] = True + current_token_ids = \ + reasoning_parser.extract_content_ids( + list(output.token_ids)) + if delta_message and delta_message.content: + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" + + # handle tool calls only after reasoning is done, + else: + delta_token_ids = list(output.token_ids) + # First time to tool call, + # add the remaining text and token ids + # to delta from previous + if not added_content_delta_arr[i]: + added_content_delta_arr[i] = True + previous_text = "" + previous_token_ids = [] + delta_text = current_text + delta_token_ids = current_token_ids + + delta_message = ( + tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + request=request)) + # when only tool calls + elif tool_choice_auto: + assert tool_parser is not None + delta_message = ( + tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=output.token_ids, + request=request)) + # when only reasoning + elif self.reasoning_parser: + delta_message = (reasoning_parser. + extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + )) + # handle streaming just a content delta + else: + delta_message = DeltaMessage(content=delta_text) + + # update the previous values for the next iteration + if tool_choice_auto or self.reasoning_parser: + assert previous_texts is not None + assert all_previous_token_ids is not None + previous_texts[i] = current_text + all_previous_token_ids[i] = current_token_ids + + # set the previous values for the next iteration + previous_num_tokens[i] += len(output.token_ids) + + # if the message delta is None (e.g. because it was a + # "control token" for tool calls or the parser otherwise + # wasn't ready to send a token, then + # get the next token without streaming a chunk + if delta_message is None: + continue + + if output.finish_reason is None: + # Send token-by-token response for each request.n + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=delta_message, + logprobs=logprobs, + finish_reason=None) + + # if the model is finished generating + else: + # check to make sure we haven't "forgotten" to stream + # any tokens that were generated but previously + # matched by partial json parsing + # only happens if we are NOT using guided decoding + auto_tools_called = False + if tool_parser: + auto_tools_called = len( + tool_parser.prev_tool_call_arr) > 0 + index = len(tool_parser.prev_tool_call_arr + ) - 1 if auto_tools_called else 0 + else: + index = 0 + + if self._should_check_for_unstreamed_tool_arg_tokens( + delta_message, output) and tool_parser: + latest_delta_len = 0 + if ((isinstance( + delta_message.tool_calls[0].function, + DeltaFunctionCall)) and isinstance( + delta_message.tool_calls[0].function. + arguments, str)): + latest_delta_len = len( + delta_message.tool_calls[0].function. + arguments) + + # get the expected call based on partial JSON + # parsing which "autocompletes" the JSON + expected_call = json.dumps( + tool_parser.prev_tool_call_arr[index].get( + "arguments", {}), + ensure_ascii=False) + + # get what we've streamed so far for arguments + # for the current tool + actual_call = tool_parser.streamed_args_for_tool[ + index] + if (latest_delta_len > 0): + actual_call = actual_call[:-latest_delta_len] + + # check to see if there's anything left to stream + remaining_call = expected_call.replace( + actual_call, "", 1) + # set that as a delta message + delta_message = DeltaMessage(tool_calls=[ + DeltaToolCall(index=index, + function=DeltaFunctionCall( + arguments=remaining_call). + model_dump(exclude_none=True)) + ]) + + # Send the finish response for each request.n only once + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=delta_message, + logprobs=logprobs, + finish_reason=output.finish_reason + if not auto_tools_called else "tool_calls", + stop_reason=output.stop_reason) + + finish_reason_sent[i] = True + + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + + # handle usage stats if requested & if continuous + if include_continuous_usage: + completion_tokens = previous_num_tokens[i] + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, + ) + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # once the final token is handled, if stream_options.include_usage + # is sent, send the usage + if include_usage: + completion_tokens = sum(previous_num_tokens) + final_usage = UsageInfo(prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + + completion_tokens) + if self.enable_prompt_tokens_details and num_cached_tokens: + final_usage.prompt_tokens_details = PromptTokenUsageInfo( + cached_tokens=num_cached_tokens) + + final_usage_chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[], + model=model_name, + usage=final_usage) + final_usage_data = (final_usage_chunk.model_dump_json( + exclude_unset=True, exclude_none=True)) + yield f"data: {final_usage_data}\n\n" + + # report to FastAPI middleware aggregate usage across all choices + num_completion_tokens = sum(previous_num_tokens) + request_metadata.final_usage_info = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_completion_tokens, + total_tokens=num_prompt_tokens + num_completion_tokens) + + except Exception as e: + # TODO: Use a vllm-specific Validation Error + logger.exception("Error in chat completion stream generator.") + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + # Send the final done message after all response.n are finished + yield "data: [DONE]\n\n" + + async def chat_completion_full_generator( + self, + request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], + request_id: str, + model_name: str, + conversation: list[ConversationMessage], + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + ) -> Union[ErrorResponse, ChatCompletionResponse]: + + created_time = int(time.time()) + final_res: Optional[RequestOutput] = None + + try: + async for res in result_generator: + final_res = res + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + assert final_res is not None + + choices: list[ChatCompletionResponseChoice] = [] + + role = self.get_chat_request_role(request) + for output in final_res.outputs: + token_ids = output.token_ids + out_logprobs = output.logprobs + + if request.logprobs and request.top_logprobs is not None: + assert out_logprobs is not None, "Did not output logprobs" + logprobs = self._create_chat_logprobs( + token_ids=token_ids, + top_logprobs=out_logprobs, + num_output_top_logprobs=request.top_logprobs, + tokenizer=tokenizer, + return_as_token_id=request.return_tokens_as_token_ids, + ) + else: + logprobs = None + auto_tools_called = False + + if self.reasoning_parser: + try: + reasoning_parser = self.reasoning_parser(tokenizer) + except RuntimeError as e: + logger.exception("Error in reasoning parser creation.") + return self.create_error_response(str(e)) + # If the reasoning parser is enabled, + # tool calls are extracted exclusively from the content. + reasoning_content, content = ( + reasoning_parser.extract_reasoning_content( + output.text, request=request)) + else: + reasoning_content = None + content = output.text + + # if auto tools are not enabled, and a named tool choice using + # outlines is not being used + if (not self.enable_auto_tools or not self.tool_parser) and \ + (not isinstance(request.tool_choice, + ChatCompletionNamedToolChoiceParam + ) and request.tool_choice != "required"): + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=content) + + # if the request uses tools and specified a tool choice + elif request.tool_choice and type( + request.tool_choice) is ChatCompletionNamedToolChoiceParam: + + tool_call_class = MistralToolCall if isinstance( + tokenizer, MistralTokenizer) else ToolCall + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content="", + tool_calls=[ + tool_call_class(function=FunctionCall( + name=request.tool_choice.function.name, + arguments=content)) + ]) + + elif request.tool_choice and request.tool_choice == "required": + tool_call_class = MistralToolCall if isinstance( + tokenizer, MistralTokenizer) else ToolCall + + # the fields of FunctionDefinition are a superset of the + # tool call outputs and can be used for parsing + assert content is not None + tool_calls = TypeAdapter( + list[FunctionDefinition]).validate_json(content) + message = ChatMessage( + role=role, + content="", + tool_calls=[ + tool_call_class(function=FunctionCall( + name=tool_call.name, + arguments=json.dumps(tool_call.parameters, + ensure_ascii=False))) + for tool_call in tool_calls + ]) + + # if the request doesn't use tool choice + # OR specifies to not use a tool + elif not request.tool_choice or request.tool_choice == "none": + + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=content) + + # handle when there are tools and tool choice is auto + elif request.tools and ( + request.tool_choice == "auto" + or request.tool_choice is None) and self.enable_auto_tools \ + and self.tool_parser: + + try: + tool_parser = self.tool_parser(tokenizer) + except RuntimeError as e: + logger.exception("Error in tool parser creation.") + return self.create_error_response(str(e)) + + tool_call_info = tool_parser.extract_tool_calls( + content if content is not None else "", request=request) + # In the OpenAI API the finish_reason is "tools_called" + # if the tool choice is auto and the model produced a tool + # call. The same is not true for named function calls + auto_tools_called = tool_call_info.tools_called + if tool_call_info.tools_called: + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=tool_call_info.content, + tool_calls=tool_call_info.tool_calls) + + else: + # FOR NOW make it a chat message; we will have to detect + # the type to make it later. + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=content) + + # undetermined case that is still important to handle + else: + logger.error( + "Error in chat_completion_full_generator - cannot determine" + " if tools should be extracted. Returning a standard chat " + "completion.") + message = ChatMessage(role=role, + reasoning_content=reasoning_content, + content=content) + + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=message, + logprobs=logprobs, + finish_reason="tool_calls" if auto_tools_called else + output.finish_reason if output.finish_reason else "stop", + stop_reason=output.stop_reason) + choices.append(choice_data) + + if request.echo: + last_msg_content: Union[str, list[dict[str, str]]] = "" + if conversation and "content" in conversation[-1] and conversation[ + -1].get("role") == role: + last_msg_content = conversation[-1]["content"] or "" + if isinstance(last_msg_content, list): + last_msg_content = "\n".join(msg['text'] + for msg in last_msg_content) + + for choice in choices: + full_message = last_msg_content + (choice.message.content + or "") + choice.message.content = full_message + + assert final_res.prompt_token_ids is not None + num_prompt_tokens = len(final_res.prompt_token_ids) + if final_res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(final_res.encoder_prompt_token_ids) + num_generated_tokens = sum( + len(output.token_ids) for output in final_res.outputs) + usage = UsageInfo(prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + + num_generated_tokens) + if self.enable_prompt_tokens_details and final_res.num_cached_tokens: + usage.prompt_tokens_details = PromptTokenUsageInfo( + cached_tokens=final_res.num_cached_tokens) + + request_metadata.final_usage_info = usage + + response = ChatCompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs), + kv_transfer_params=final_res.kv_transfer_params, + ) + + return response + + def _get_top_logprobs( + self, logprobs: dict[int, Logprob], top_logprobs: Optional[int], + tokenizer: AnyTokenizer, + should_return_as_token_id: bool) -> list[ChatCompletionLogProb]: + return [ + ChatCompletionLogProb(token=(token := self._get_decoded_token( + p[1], + p[0], + tokenizer, + return_as_token_id=should_return_as_token_id)), + logprob=max(p[1].logprob, -9999.0), + bytes=list( + token.encode("utf-8", errors="replace"))) + for i, p in enumerate(logprobs.items()) + if top_logprobs and i < top_logprobs + ] + + def _create_chat_logprobs( + self, + token_ids: GenericSequence[int], + top_logprobs: GenericSequence[Optional[dict[int, Logprob]]], + tokenizer: AnyTokenizer, + num_output_top_logprobs: Optional[int] = None, + return_as_token_id: Optional[bool] = None, + ) -> ChatCompletionLogProbs: + """Create OpenAI-style logprobs.""" + logprobs_content: list[ChatCompletionLogProbsContent] = [] + + should_return_as_token_id = return_as_token_id if \ + return_as_token_id is not None else self.return_tokens_as_token_ids + for i, token_id in enumerate(token_ids): + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is None or step_top_logprobs.get( + token_id) is None: + token = tokenizer.decode(token_id) + if should_return_as_token_id: + token = f"token_id:{token_id}" + + logprobs_content.append( + ChatCompletionLogProbsContent( + token=token, + bytes=list(token.encode("utf-8", errors="replace")), + )) + else: + step_token = step_top_logprobs[token_id] + step_decoded = step_token.decoded_token + + logprobs_content.append( + ChatCompletionLogProbsContent( + token=self._get_decoded_token( + step_token, + token_id, + tokenizer, + should_return_as_token_id, + ), + logprob=max(step_token.logprob, -9999.0), + bytes=None if step_decoded is None else list( + step_decoded.encode("utf-8", errors="replace")), + top_logprobs=self._get_top_logprobs( + step_top_logprobs, num_output_top_logprobs, + tokenizer, should_return_as_token_id), + )) + + return ChatCompletionLogProbs(content=logprobs_content) + + def _should_stream_with_auto_tool_parsing(self, + request: ChatCompletionRequest): + """ + Utility function to check if streamed tokens should go through the tool + call parser that was configured. + + We only want to do this IF user-provided tools are set, a tool parser + is configured, "auto" tool choice is enabled, and the request's tool + choice field indicates that "auto" tool choice should be used. + """ + return (request.tools and self.tool_parser and self.enable_auto_tools + and request.tool_choice in ['auto', None]) + + def _should_check_for_unstreamed_tool_arg_tokens( + self, + delta_message: Optional[DeltaMessage], + output: CompletionOutput, + ) -> bool: + """ + Check to see if we should check for unstreamed tool arguments tokens. + This is only applicable when auto tool parsing is enabled, the delta + is a tool call with arguments. + """ + + # yapf: disable + return bool( + # if there is a delta message that includes tool calls which + # include a function that has arguments + output.finish_reason is not None + and self.enable_auto_tools and self.tool_parser and delta_message + and delta_message.tool_calls and delta_message.tool_calls[0] + and delta_message.tool_calls[0].function + and delta_message.tool_calls[0].function.arguments is not None + ) diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py new file mode 100644 index 0000000..3ac4f01 --- /dev/null +++ b/vllm/entrypoints/openai/serving_classification.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from http import HTTPStatus +from typing import Optional, Union, cast + +import numpy as np +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (ClassificationData, + ClassificationRequest, + ClassificationResponse, + ErrorResponse, UsageInfo) +# yapf: enable +from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext, + OpenAIServing, + ServeContext) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger +from vllm.outputs import ClassificationOutput, PoolingRequestOutput + +logger = init_logger(__name__) + + +class ClassificationMixin(OpenAIServing): + + async def _preprocess( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """ + Process classification inputs: tokenize text, resolve adapters, + and prepare model-specific inputs. + """ + ctx = cast(ClassificationServeContext, ctx) + if isinstance(ctx.request.input, str) and not ctx.request.input: + return self.create_error_response( + "Input cannot be empty for classification", + status_code=HTTPStatus.BAD_REQUEST, + ) + + if isinstance(ctx.request.input, list) and len(ctx.request.input) == 0: + return None + + try: + ( + ctx.lora_request, + ctx.prompt_adapter_request, + ) = self._maybe_get_adapters(ctx.request) + + ctx.tokenizer = await self.engine_client.get_tokenizer( + ctx.lora_request) + + if ctx.prompt_adapter_request is not None: + raise NotImplementedError( + "Prompt adapter is not supported for classification models" + ) + + ( + ctx.request_prompts, + ctx.engine_prompts, + ) = await self._preprocess_completion( + ctx.request, + ctx.tokenizer, + ctx.request.input, + truncate_prompt_tokens=ctx.request.truncate_prompt_tokens, + ) + + return None + + except (ValueError, TypeError) as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + def _build_response( + self, + ctx: ServeContext, + ) -> Union[ClassificationResponse, ErrorResponse]: + """ + Convert model outputs to a formatted classification response + with probabilities and labels. + """ + ctx = cast(ClassificationServeContext, ctx) + items: list[ClassificationData] = [] + num_prompt_tokens = 0 + + final_res_batch_checked = cast(list[PoolingRequestOutput], + ctx.final_res_batch) + + for idx, final_res in enumerate(final_res_batch_checked): + classify_res = ClassificationOutput.from_base(final_res.outputs) + + probs = classify_res.probs + predicted_index = int(np.argmax(probs)) + label = getattr(self.model_config.hf_config, "id2label", + {}).get(predicted_index) + + item = ClassificationData( + index=idx, + label=label, + probs=probs, + num_classes=len(probs), + ) + + items.append(item) + prompt_token_ids = final_res.prompt_token_ids + num_prompt_tokens += len(prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, + ) + + return ClassificationResponse( + id=ctx.request_id, + created=ctx.created_time, + model=ctx.model_name, + data=items, + usage=usage, + ) + + +class ServingClassification(ClassificationMixin): + request_id_prefix = "classify" + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + ) -> None: + super().__init__( + engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + ) + + async def create_classify( + self, + request: ClassificationRequest, + raw_request: Request, + ) -> Union[ClassificationResponse, ErrorResponse]: + model_name = self._get_model_name(request.model) + request_id = (f"{self.request_id_prefix}-" + f"{self._base_request_id(raw_request)}") + + ctx = ClassificationServeContext( + request=request, + raw_request=raw_request, + model_name=model_name, + request_id=request_id, + ) + + return await super().handle(ctx) # type: ignore diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py new file mode 100644 index 0000000..6c9c29b --- /dev/null +++ b/vllm/entrypoints/openai/serving_completion.py @@ -0,0 +1,618 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import time +from collections.abc import AsyncGenerator, AsyncIterator +from collections.abc import Sequence as GenericSequence +from typing import Optional, Union, cast + +import jinja2 +from fastapi import Request +from typing_extensions import assert_never + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.openai.protocol import (CompletionLogProbs, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionResponseStreamChoice, + CompletionStreamResponse, + ErrorResponse, + RequestResponseMetadata, + UsageInfo) +from vllm.entrypoints.openai.serving_engine import ( + EmbedsPrompt as ServingEngineEmbedsPrompt) +from vllm.entrypoints.openai.serving_engine import (OpenAIServing, + TextTokensPrompt, + clamp_prompt_logprobs, + is_text_tokens_prompt) +# yapf: enable +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.utils import get_max_tokens +from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt, + is_tokens_prompt) +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sequence import Logprob +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import merge_async_iterators + +logger = init_logger(__name__) + + +class OpenAIServingCompletion(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + enable_force_include_usage: bool = False, + ): + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + enable_force_include_usage=enable_force_include_usage) + self.default_sampling_params = ( + self.model_config.get_diff_sampling_param()) + if self.default_sampling_params: + source = self.model_config.generation_config + source = "model" if source == "auto" else source + logger.info("Using default completion sampling params from %s: %s", + source, self.default_sampling_params) + + async def create_completion( + self, + request: CompletionRequest, + raw_request: Optional[Request] = None, + ) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]: + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/completions/create + for the API specification. This API mimics the OpenAI Completion API. + + NOTE: Currently we do not support the following feature: + - suffix (the language models we currently support do not support + suffix) + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + + # Return error for unsupported features. + if request.suffix is not None: + return self.create_error_response( + "suffix is not currently supported") + + if request.echo and request.prompt_embeds is not None: + return self.create_error_response( + "Echo is unsupported with prompt embeds.") + + request_id = f"cmpl-{self._base_request_id(raw_request)}" + created_time = int(time.time()) + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + request_prompts, engine_prompts = await self._preprocess_completion( + request, + tokenizer, + request.prompt, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + except TypeError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + except RuntimeError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + except jinja2.TemplateError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + # Schedule the request and get the result generator. + generators: list[AsyncGenerator[RequestOutput, None]] = [] + try: + for i, engine_prompt in enumerate(engine_prompts): + sampling_params: Union[SamplingParams, BeamSearchParams] + # Mypy does not infer that engine_prompt will have only one of + # "prompt_token_ids" or "prompt_embeds" defined, and both of + # these as Union[object, the expected type], where it infers + # object if engine_prompt is a subclass of one of the + # typeddicts that defines both keys. Worse, because of + # https://github.com/python/mypy/issues/8586, mypy does not + # infer the type of engine_prompt correctly because of the + # enumerate. So we need an unnecessary cast here. + engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt], + engine_prompt) + if is_embeds_prompt(engine_prompt): + input_length = len(engine_prompt["prompt_embeds"]) + elif is_tokens_prompt(engine_prompt): + input_length = len(engine_prompt["prompt_token_ids"]) + else: + assert_never(engine_prompt) + + if self.default_sampling_params is None: + self.default_sampling_params = {} + + max_tokens = get_max_tokens( + max_model_len=self.max_model_len, + request=request, + input_length=input_length, + default_sampling_params=self.default_sampling_params) + + if request.use_beam_search: + sampling_params = request.to_beam_search_params( + max_tokens, self.default_sampling_params) + else: + sampling_params = request.to_sampling_params( + max_tokens, self.model_config.logits_processor_pattern, + self.default_sampling_params) + + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + request_prompts[i], + params=sampling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + # Mypy inconsistently requires this second cast in different + # environments. It shouldn't be necessary (redundant from above) + # but pre-commit in CI fails without it. + engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt], + engine_prompt) + if isinstance(sampling_params, BeamSearchParams): + generator = self.engine_client.beam_search( + prompt=engine_prompt, + request_id=request_id, + params=sampling_params, + lora_request=lora_request, + ) + else: + generator = self.engine_client.generate( + engine_prompt, + sampling_params, + request_id_item, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers, + priority=request.priority, + ) + + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + result_generator = merge_async_iterators(*generators) + + model_name = self._get_model_name(request.model, lora_request) + num_prompts = len(engine_prompts) + + # Similar to the OpenAI API, when n != best_of, we do not stream the + # results. Noting that best_of is only supported in V0. In addition, + # we do not stream the results when use beam search. + stream = (request.stream + and (request.best_of is None or request.n == request.best_of) + and not request.use_beam_search) + + # Streaming response + if stream: + return self.completion_stream_generator( + request, + request_prompts, + result_generator, + request_id, + created_time, + model_name, + num_prompts=num_prompts, + tokenizer=tokenizer, + request_metadata=request_metadata, + enable_force_include_usage=self.enable_force_include_usage) + + # Non-streaming response + final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts + try: + async for i, res in result_generator: + final_res_batch[i] = res + + for i, final_res in enumerate(final_res_batch): + assert final_res is not None + + # The output should contain the input text + # We did not pass it into vLLM engine to avoid being redundant + # with the inputs token IDs + if final_res.prompt is None: + request_prompt = request_prompts[i] + if is_text_tokens_prompt(request_prompt): + final_res.prompt = request_prompt["prompt"] + else: + final_res.prompt = None + + final_res_batch_checked = cast(list[RequestOutput], + final_res_batch) + + response = self.request_output_to_completion_response( + final_res_batch_checked, + request, + request_id, + created_time, + model_name, + tokenizer, + request_metadata, + ) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + # When user requests streaming but we don't stream, we still need to + # return a streaming response with a single event. + if request.stream: + response_json = response.model_dump_json() + + async def fake_stream_generator() -> AsyncGenerator[str, None]: + yield f"data: {response_json}\n\n" + yield "data: [DONE]\n\n" + + return fake_stream_generator() + + return response + + async def completion_stream_generator( + self, + request: CompletionRequest, + request_prompts: list[Union[TextTokensPrompt, + ServingEngineEmbedsPrompt]], + result_generator: AsyncIterator[tuple[int, RequestOutput]], + request_id: str, + created_time: int, + model_name: str, + num_prompts: int, + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + enable_force_include_usage: bool, + ) -> AsyncGenerator[str, None]: + num_choices = 1 if request.n is None else request.n + previous_text_lens = [0] * num_choices * num_prompts + previous_num_tokens = [0] * num_choices * num_prompts + has_echoed = [False] * num_choices * num_prompts + num_prompt_tokens = [0] * num_prompts + + stream_options = request.stream_options + if stream_options: + include_usage = stream_options.include_usage or \ + enable_force_include_usage + include_continuous_usage = include_usage and \ + stream_options.continuous_usage_stats + else: + include_usage, include_continuous_usage = False, False + + try: + async for prompt_idx, res in result_generator: + prompt_token_ids = res.prompt_token_ids + prompt_logprobs = res.prompt_logprobs + + if res.prompt is not None: + prompt_text = res.prompt + else: + request_prompt = request_prompts[prompt_idx] + if is_text_tokens_prompt(request_prompt): + prompt_text = request_prompt["prompt"] + else: + prompt_text = None + + # Prompt details are excluded from later streamed outputs + if prompt_token_ids is not None: + num_prompt_tokens[prompt_idx] = len(prompt_token_ids) + + delta_token_ids: GenericSequence[int] + out_logprobs: Optional[GenericSequence[Optional[dict[ + int, Logprob]]]] + + for output in res.outputs: + i = output.index + prompt_idx * num_choices + + assert request.max_tokens is not None + if request.echo and not has_echoed[i]: + assert prompt_token_ids is not None + assert prompt_text is not None + if request.max_tokens == 0: + # only return the prompt + delta_text = prompt_text + delta_token_ids = prompt_token_ids + out_logprobs = prompt_logprobs + else: + # echo the prompt and first token + delta_text = prompt_text + output.text + delta_token_ids = [ + *prompt_token_ids, *output.token_ids + ] + out_logprobs = [ + *(prompt_logprobs or []), + *(output.logprobs or []), + ] + has_echoed[i] = True + else: + # return just the delta + delta_text = output.text + delta_token_ids = output.token_ids + out_logprobs = output.logprobs + + if not delta_text and not delta_token_ids \ + and not previous_num_tokens[i]: + # Chunked prefill case, don't return empty chunks + continue + + if request.logprobs is not None: + assert out_logprobs is not None, ( + "Did not output logprobs") + logprobs = self._create_completion_logprobs( + token_ids=delta_token_ids, + top_logprobs=out_logprobs, + num_output_top_logprobs=request.logprobs, + tokenizer=tokenizer, + initial_text_offset=previous_text_lens[i], + return_as_token_id=request. + return_tokens_as_token_ids, + ) + else: + logprobs = None + + previous_text_lens[i] += len(output.text) + previous_num_tokens[i] += len(output.token_ids) + finish_reason = output.finish_reason + stop_reason = output.stop_reason + + chunk = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[ + CompletionResponseStreamChoice( + index=i, + text=delta_text, + logprobs=logprobs, + finish_reason=finish_reason, + stop_reason=stop_reason, + ) + ]) + if include_continuous_usage: + prompt_tokens = num_prompt_tokens[prompt_idx] + completion_tokens = previous_num_tokens[i] + chunk.usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + + response_json = chunk.model_dump_json(exclude_unset=False) + yield f"data: {response_json}\n\n" + + total_prompt_tokens = sum(num_prompt_tokens) + total_completion_tokens = sum(previous_num_tokens) + final_usage_info = UsageInfo( + prompt_tokens=total_prompt_tokens, + completion_tokens=total_completion_tokens, + total_tokens=total_prompt_tokens + total_completion_tokens) + + if include_usage: + final_usage_chunk = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[], + usage=final_usage_info, + ) + final_usage_data = (final_usage_chunk.model_dump_json( + exclude_unset=False, exclude_none=True)) + yield f"data: {final_usage_data}\n\n" + + # report to FastAPI middleware aggregate usage across all choices + request_metadata.final_usage_info = final_usage_info + + except Exception as e: + # TODO: Use a vllm-specific Validation Error + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + yield "data: [DONE]\n\n" + + def request_output_to_completion_response( + self, + final_res_batch: list[RequestOutput], + request: CompletionRequest, + request_id: str, + created_time: int, + model_name: str, + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + ) -> CompletionResponse: + choices: list[CompletionResponseChoice] = [] + num_prompt_tokens = 0 + num_generated_tokens = 0 + + for final_res in final_res_batch: + prompt_token_ids = final_res.prompt_token_ids + assert prompt_token_ids is not None + prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs) + prompt_text = final_res.prompt + + token_ids: GenericSequence[int] + out_logprobs: Optional[GenericSequence[Optional[dict[int, + Logprob]]]] + + for output in final_res.outputs: + assert request.max_tokens is not None + if request.echo: + assert prompt_text is not None + if request.max_tokens == 0: + token_ids = prompt_token_ids + out_logprobs = prompt_logprobs + output_text = prompt_text + else: + token_ids = [*prompt_token_ids, *output.token_ids] + + if request.logprobs is None: + out_logprobs = None + else: + assert prompt_logprobs is not None + assert output.logprobs is not None + out_logprobs = [ + *prompt_logprobs, + *output.logprobs, + ] + + output_text = prompt_text + output.text + else: + token_ids = output.token_ids + out_logprobs = output.logprobs + output_text = output.text + + if request.logprobs is not None: + assert out_logprobs is not None, "Did not output logprobs" + logprobs = self._create_completion_logprobs( + token_ids=token_ids, + top_logprobs=out_logprobs, + tokenizer=tokenizer, + num_output_top_logprobs=request.logprobs, + return_as_token_id=request.return_tokens_as_token_ids, + ) + else: + logprobs = None + + choice_data = CompletionResponseChoice( + index=len(choices), + text=output_text, + logprobs=logprobs, + finish_reason=output.finish_reason, + stop_reason=output.stop_reason, + prompt_logprobs=final_res.prompt_logprobs, + ) + choices.append(choice_data) + + num_generated_tokens += len(output.token_ids) + + num_prompt_tokens += len(prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + + request_metadata.final_usage_info = usage + + return CompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + kv_transfer_params=final_res_batch[0].kv_transfer_params) + + def _create_completion_logprobs( + self, + token_ids: GenericSequence[int], + top_logprobs: GenericSequence[Optional[dict[int, Logprob]]], + num_output_top_logprobs: int, + tokenizer: AnyTokenizer, + initial_text_offset: int = 0, + return_as_token_id: Optional[bool] = None, + ) -> CompletionLogProbs: + """Create logprobs for OpenAI Completion API.""" + out_text_offset: list[int] = [] + out_token_logprobs: list[Optional[float]] = [] + out_tokens: list[str] = [] + out_top_logprobs: list[Optional[dict[str, float]]] = [] + + last_token_len = 0 + + should_return_as_token_id = return_as_token_id if \ + return_as_token_id is not None else self.return_tokens_as_token_ids + for i, token_id in enumerate(token_ids): + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is None: + token = tokenizer.decode(token_id) + if should_return_as_token_id: + token = f"token_id:{token_id}" + + out_tokens.append(token) + out_token_logprobs.append(None) + out_top_logprobs.append(None) + else: + step_token = step_top_logprobs[token_id] + + token = self._get_decoded_token( + step_token, + token_id, + tokenizer, + return_as_token_id=should_return_as_token_id, + ) + token_logprob = max(step_token.logprob, -9999.0) + + out_tokens.append(token) + out_token_logprobs.append(token_logprob) + + # makes sure to add the top num_output_top_logprobs + 1 + # logprobs, as defined in the openai API + # (cf. https://github.com/openai/openai-openapi/blob/ + # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153) + out_top_logprobs.append({ + # Convert float("-inf") to the + # JSON-serializable float that OpenAI uses + self._get_decoded_token(top_lp[1], + top_lp[0], + tokenizer, + return_as_token_id=should_return_as_token_id): + max(top_lp[1].logprob, -9999.0) + for i, top_lp in enumerate(step_top_logprobs.items()) + if num_output_top_logprobs >= i + }) + + if len(out_text_offset) == 0: + out_text_offset.append(initial_text_offset) + else: + out_text_offset.append(out_text_offset[-1] + last_token_len) + last_token_len = len(token) + + return CompletionLogProbs( + text_offset=out_text_offset, + token_logprobs=out_token_logprobs, + tokens=out_tokens, + top_logprobs=out_top_logprobs, + ) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py new file mode 100644 index 0000000..e87decf --- /dev/null +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -0,0 +1,201 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import base64 +from typing import Final, Literal, Optional, Union, cast + +import numpy as np +from fastapi import Request +from typing_extensions import assert_never, override + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, + EmbeddingRequest, + EmbeddingResponse, + EmbeddingResponseData, + ErrorResponse, UsageInfo) +from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext, + OpenAIServing, + ServeContext) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger +from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, + PoolingRequestOutput) + +logger = init_logger(__name__) + + +def _get_embedding( + output: EmbeddingOutput, + encoding_format: Literal["float", "base64"], +) -> Union[list[float], str]: + if encoding_format == "float": + return output.embedding + elif encoding_format == "base64": + # Force to use float32 for base64 encoding + # to match the OpenAI python client behavior + embedding_bytes = np.array(output.embedding, dtype="float32").tobytes() + return base64.b64encode(embedding_bytes).decode("utf-8") + + assert_never(encoding_format) + + +class EmbeddingMixin(OpenAIServing): + + async def _preprocess( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + ctx = cast(EmbeddingServeContext, ctx) + try: + ( + ctx.lora_request, + ctx.prompt_adapter_request, + ) = self._maybe_get_adapters(ctx.request) + + tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request + ) + + if ctx.prompt_adapter_request is not None: + raise NotImplementedError("Prompt adapter is not supported " + "for embedding models") + + if isinstance(ctx.request, EmbeddingChatRequest): + ( + _, + ctx.request_prompts, + ctx.engine_prompts, + ) = await self._preprocess_chat( + ctx.request, + tokenizer, + ctx.request.messages, + chat_template=ctx.request.chat_template + or ctx.chat_template, + chat_template_content_format=ctx. + chat_template_content_format, + # In embedding requests, we are not generating tokens, + # so there is no need to append extra tokens to the input + add_generation_prompt=False, + continue_final_message=False, + truncate_prompt_tokens=ctx.truncate_prompt_tokens, + add_special_tokens=ctx.request.add_special_tokens, + ) + else: + (ctx.request_prompts, + ctx.engine_prompts) = await self._preprocess_completion( + ctx.request, + tokenizer, + ctx.request.input, + truncate_prompt_tokens=ctx.truncate_prompt_tokens, + add_special_tokens=ctx.request.add_special_tokens, + ) + return None + except (ValueError, TypeError) as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + def _build_response( + self, + ctx: ServeContext, + ) -> Union[EmbeddingResponse, ErrorResponse]: + items: list[EmbeddingResponseData] = [] + num_prompt_tokens = 0 + + final_res_batch_checked = cast(list[PoolingRequestOutput], + ctx.final_res_batch) + + for idx, final_res in enumerate(final_res_batch_checked): + embedding_res = EmbeddingRequestOutput.from_base(final_res) + + item = EmbeddingResponseData( + index=idx, + embedding=_get_embedding(embedding_res.outputs, + ctx.request.encoding_format), + ) + prompt_token_ids = final_res.prompt_token_ids + + items.append(item) + num_prompt_tokens += len(prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, + ) + + return EmbeddingResponse( + id=ctx.request_id, + created=ctx.created_time, + model=ctx.model_name, + data=items, + usage=usage, + ) + + +class OpenAIServingEmbedding(EmbeddingMixin): + request_id_prefix = "embd" + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + ) -> None: + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger) + + self.chat_template = chat_template + self.chat_template_content_format: Final = chat_template_content_format + + async def create_embedding( + self, + request: EmbeddingRequest, + raw_request: Optional[Request] = None, + ) -> Union[EmbeddingResponse, ErrorResponse]: + """ + Embedding API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/embeddings/create + for the API specification. This API mimics the OpenAI Embedding API. + """ + model_name = self._get_model_name(request.model) + request_id = (f"{self.request_id_prefix}-" + f"{self._base_request_id(raw_request)}") + + ctx = EmbeddingServeContext( + request=request, + raw_request=raw_request, + model_name=model_name, + request_id=request_id, + chat_template=self.chat_template, + chat_template_content_format=self.chat_template_content_format, + ) + + return await super().handle(ctx) # type: ignore + + @override + def _validate_request( + self, + ctx: ServeContext[EmbeddingRequest], + ) -> Optional[ErrorResponse]: + if error := super()._validate_request(ctx): + return error + + ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens + + pooling_params = ctx.request.to_pooling_params() + + try: + pooling_params.verify(self.model_config) + except ValueError as e: + return self.create_error_response(str(e)) + + return None diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py new file mode 100644 index 0000000..9a5adda --- /dev/null +++ b/vllm/entrypoints/openai/serving_engine.py @@ -0,0 +1,997 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import base64 +import io +import json +import sys +import time +from collections.abc import (AsyncGenerator, Iterable, Iterator, Mapping, + Sequence) +from concurrent.futures.thread import ThreadPoolExecutor +from http import HTTPStatus +from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional, + TypeVar, Union, cast, overload) + +import torch +from fastapi import Request +from pydantic import BaseModel, ConfigDict, Field +from starlette.datastructures import Headers +from typing_extensions import TypeIs + +if sys.version_info >= (3, 12): + from typing import TypedDict +else: + from typing_extensions import TypedDict + +if sys.version_info >= (3, 12): + from typing import TypedDict +else: + from typing_extensions import TypedDict + +import vllm.envs as envs +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, + ChatTemplateContentFormatOption, + ConversationMessage, + apply_hf_chat_template, + apply_mistral_chat_template, + parse_chat_messages_futures, + resolve_chat_template_content_format) +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionResponse, + ClassificationRequest, + ClassificationResponse, + CompletionRequest, + CompletionResponse, + DetokenizeRequest, + EmbeddingChatRequest, + EmbeddingCompletionRequest, + EmbeddingRequest, + EmbeddingResponse, ErrorResponse, + PoolingResponse, RerankRequest, + ScoreRequest, ScoreResponse, + TokenizeChatRequest, + TokenizeCompletionRequest, + TokenizeResponse, + TranscriptionRequest, + TranscriptionResponse, + TranslationRequest) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.openai.tool_parsers import ToolParser +# yapf: enable +from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt +from vllm.inputs.data import TokensPrompt as EngineTokensPrompt +from vllm.inputs.parse import parse_and_batch_prompt +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin + MultiModalDataDict) +from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sequence import Logprob, PromptLogprobs +from vllm.tracing import (contains_trace_headers, extract_trace_headers, + log_tracing_disabled_warning) +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer + +from vllm.transformers_utils.tokenizers import CPM9GTokenizer +from vllm.utils import (is_list_of, make_async, merge_async_iterators, + random_uuid) + +logger = init_logger(__name__) + +CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest, + EmbeddingCompletionRequest, RerankRequest, + ClassificationRequest, ScoreRequest, + TokenizeCompletionRequest] + +ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, + TokenizeChatRequest] +SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest] +AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest] + +AnyResponse = Union[ + CompletionResponse, + ChatCompletionResponse, + EmbeddingResponse, + TranscriptionResponse, + TokenizeResponse, + PoolingResponse, + ClassificationResponse, + ScoreResponse, +] + + +class TextTokensPrompt(TypedDict): + prompt: str + prompt_token_ids: list[int] + + +class EmbedsPrompt(TypedDict): + prompt_embeds: torch.Tensor + + +RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt] + + +def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]: + return (isinstance(prompt, dict) and "prompt_token_ids" in prompt + and "prompt_embeds" not in prompt) + + +def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]: + return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt + and "prompt_embeds" in prompt) + + +RequestT = TypeVar("RequestT", bound=AnyRequest) + + +class RequestProcessingMixin(BaseModel): + """ + Mixin for request processing, + handling prompt preparation and engine input. + """ + request_prompts: Optional[Sequence[RequestPrompt]] = [] + engine_prompts: Optional[Union[list[EngineTokensPrompt], + list[EngineEmbedsPrompt]]] = [] + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class ResponseGenerationMixin(BaseModel): + """ + Mixin for response generation, + managing result generators and final batch results. + """ + result_generator: Optional[AsyncGenerator[tuple[int, Union[ + RequestOutput, PoolingRequestOutput]], None]] = None + final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field( + default_factory=list) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel, + Generic[RequestT]): + # Shared across all requests + request: RequestT + raw_request: Optional[Request] = None + model_name: str + request_id: str + created_time: int = Field(default_factory=lambda: int(time.time())) + lora_request: Optional[LoRARequest] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + + # Shared across most requests + tokenizer: Optional[AnyTokenizer] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + + # `protected_namespaces` resolves Pydantic v2's warning + # on conflict with protected namespace "model_" + model_config = ConfigDict( + protected_namespaces=(), + arbitrary_types_allowed=True, + ) + + +ClassificationServeContext = ServeContext[ClassificationRequest] + + +class EmbeddingServeContext(ServeContext[EmbeddingRequest]): + chat_template: Optional[str] = None + chat_template_content_format: ChatTemplateContentFormatOption + + +# Used to resolve the Pydantic error related to +# forward reference of MultiModalDataDict in TokensPrompt +RequestProcessingMixin.model_rebuild() +ServeContext.model_rebuild() +ClassificationServeContext.model_rebuild() +EmbeddingServeContext.model_rebuild() + + +class OpenAIServing: + request_id_prefix: ClassVar[str] = """ + A short string prepended to every request’s ID (e.g. "embd", "classify") + so you can easily tell “this ID came from Embedding vs Classification.” + """ + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + enable_force_include_usage: bool = False, + ): + super().__init__() + + self.engine_client = engine_client + self.model_config = model_config + self.max_model_len = model_config.max_model_len + self.tokenizer_mode = model_config.tokenizer_mode + + if model_config.tokenizer_mode == "cpm": + self.tokenizer = CPM9GTokenizer(model_config.model, trust_remote_code=True) + + self.models = models + + self.request_logger = request_logger + self.return_tokens_as_token_ids = return_tokens_as_token_ids + self.enable_force_include_usage = enable_force_include_usage + + self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) + + self._tokenize_prompt_input_async = make_async( + self._tokenize_prompt_input, executor=self._tokenizer_executor) + self._tokenize_prompt_input_or_inputs_async = make_async( + self._tokenize_prompt_input_or_inputs, + executor=self._tokenizer_executor) + + async def _preprocess( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """ + Default preprocessing hook. Subclasses may override + to prepare `ctx` (classification, embedding, etc.). + """ + return None + + def _build_response( + self, + ctx: ServeContext, + ) -> Union[AnyResponse, ErrorResponse]: + """ + Default response builder. Subclass may override this method + to return the appropriate response object. + """ + return self.create_error_response("unimplemented endpoint") + + async def handle( + self, + ctx: ServeContext, + ) -> Union[AnyResponse, ErrorResponse]: + generation: AsyncGenerator[Union[AnyResponse, ErrorResponse], None] + generation = self._pipeline(ctx) + + async for response in generation: + return response + + return self.create_error_response("No response yielded from pipeline") + + async def _pipeline( + self, + ctx: ServeContext, + ) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]: + """Execute the request processing pipeline yielding responses.""" + if error := await self._check_model(ctx.request): + yield error + if error := self._validate_request(ctx): + yield error + + preprocess_ret = await self._preprocess(ctx) + if isinstance(preprocess_ret, ErrorResponse): + yield preprocess_ret + + generators_ret = await self._prepare_generators(ctx) + if isinstance(generators_ret, ErrorResponse): + yield generators_ret + + collect_ret = await self._collect_batch(ctx) + if isinstance(collect_ret, ErrorResponse): + yield collect_ret + + yield self._build_response(ctx) + + def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]: + truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", + None) + + if truncate_prompt_tokens is not None: + if truncate_prompt_tokens <= self.max_model_len: + ctx.truncate_prompt_tokens = truncate_prompt_tokens + else: + return self.create_error_response( + "truncate_prompt_tokens value is " + "greater than max_model_len." + " Please, select a smaller truncation size.") + return None + + async def _prepare_generators( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """Schedule the request and get the result generator.""" + generators: list[AsyncGenerator[Union[RequestOutput, + PoolingRequestOutput], + None]] = [] + + try: + trace_headers = (None if ctx.raw_request is None else await + self._get_trace_headers(ctx.raw_request.headers)) + + if not hasattr(ctx.request, "to_pooling_params"): + return self.create_error_response( + "Request type does not support pooling parameters") + + pooling_params = ctx.request.to_pooling_params() + + if ctx.engine_prompts is None: + return self.create_error_response( + "Engine prompts not available") + + for i, engine_prompt in enumerate(ctx.engine_prompts): + request_id_item = f"{ctx.request_id}-{i}" + + if ctx.request_prompts is None: + return self.create_error_response( + "Request prompts not available") + + self._log_inputs( + request_id_item, + ctx.request_prompts[i], + params=pooling_params, + lora_request=ctx.lora_request, + prompt_adapter_request=ctx.prompt_adapter_request) + + # Mypy has an existing bug related to inferring the variance of + # TypedDicts with `builtins.enumerate`: + # https://github.com/python/mypy/issues/8586#issuecomment-2867698435 + engine_prompt = cast( + Union[EngineTokensPrompt, EngineEmbedsPrompt], + engine_prompt) + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=ctx.lora_request, + trace_headers=trace_headers, + priority=getattr(ctx.request, "priority", 0), + ) + + generators.append(generator) + + ctx.result_generator = merge_async_iterators(*generators) + + return None + + except Exception as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + async def _collect_batch( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + """Collect batch results from the result generator.""" + try: + if ctx.engine_prompts is None: + return self.create_error_response( + "Engine prompts not available") + + num_prompts = len(ctx.engine_prompts) + final_res_batch: list[Optional[Union[RequestOutput, + PoolingRequestOutput]]] + final_res_batch = [None] * num_prompts + + if ctx.result_generator is None: + return self.create_error_response( + "Result generator not available") + + async for i, res in ctx.result_generator: + final_res_batch[i] = res + + if None in final_res_batch: + return self.create_error_response( + "Failed to generate results for all prompts") + + ctx.final_res_batch = [ + res for res in final_res_batch if res is not None + ] + + return None + + except Exception as e: + return self.create_error_response(str(e)) + + def create_error_response( + self, + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: + return ErrorResponse(message=message, + type=err_type, + code=status_code.value) + + def create_streaming_error_response( + self, + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str: + json_str = json.dumps({ + "error": + self.create_error_response(message=message, + err_type=err_type, + status_code=status_code).model_dump() + }) + return json_str + + async def _check_model( + self, + request: AnyRequest, + ) -> Optional[ErrorResponse]: + + error_response = None + + if self._is_model_supported(request.model): + return None + if request.model in [ + lora.lora_name for lora in self.models.lora_requests + ]: + return None + if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and ( + load_result := await self.models.resolve_lora(request.model)): + if isinstance(load_result, LoRARequest): + return None + if isinstance(load_result, ErrorResponse) and \ + load_result.code == HTTPStatus.BAD_REQUEST.value: + error_response = load_result + if request.model in [ + prompt_adapter.prompt_adapter_name + for prompt_adapter in self.models.prompt_adapter_requests + ]: + return None + + return error_response or self.create_error_response( + message=f"The model `{request.model}` does not exist.", + err_type="NotFoundError", + status_code=HTTPStatus.NOT_FOUND) + + def _maybe_get_adapters( + self, request: AnyRequest + ) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[ + None, PromptAdapterRequest]]: + if self._is_model_supported(request.model): + return None, None + for lora in self.models.lora_requests: + if request.model == lora.lora_name: + return lora, None + for prompt_adapter in self.models.prompt_adapter_requests: + if request.model == prompt_adapter.prompt_adapter_name: + return None, prompt_adapter + # if _check_model has been called earlier, this will be unreachable + raise ValueError(f"The model `{request.model}` does not exist.") + + def _normalize_prompt_text_to_input( + self, + request: AnyRequest, + tokenizer: AnyTokenizer, + prompt: str, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]], + add_special_tokens: bool, + ) -> TextTokensPrompt: + if (self.model_config.encoder_config is not None + and self.model_config.encoder_config.get( + "do_lower_case", False)): + prompt = prompt.lower() + + if truncate_prompt_tokens is None: + encoded = tokenizer(prompt, add_special_tokens=add_special_tokens) + elif truncate_prompt_tokens < 0: + # Negative means we cap at the model's max length + encoded = tokenizer(prompt, + add_special_tokens=add_special_tokens, + truncation=True, + max_length=self.max_model_len) + else: + encoded = tokenizer(prompt, + add_special_tokens=add_special_tokens, + truncation=True, + max_length=truncate_prompt_tokens) + + if self.tokenizer_mode == "cpm": + input_ids = [self.tokenizer.bos_id] + self.tokenizer.encode(prompt) + else: + input_ids = encoded.input_ids + + input_text = prompt + + return self._validate_input(request, input_ids, input_text) + + def _normalize_prompt_tokens_to_input( + self, + request: AnyRequest, + tokenizer: AnyTokenizer, + prompt_ids: list[int], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]], + ) -> TextTokensPrompt: + if truncate_prompt_tokens is None: + input_ids = prompt_ids + elif truncate_prompt_tokens < 0: + input_ids = prompt_ids[-self.max_model_len:] + else: + input_ids = prompt_ids[-truncate_prompt_tokens:] + + input_text = tokenizer.decode(input_ids) if self.tokenizer_mode != "cpm" else self.tokenizer.decode_all(input_ids) + + return self._validate_input(request, input_ids, input_text) + + def _validate_input( + self, + request: AnyRequest, + input_ids: list[int], + input_text: str, + ) -> TextTokensPrompt: + token_num = len(input_ids) + + # Note: EmbeddingRequest, ClassificationRequest, + # and ScoreRequest doesn't have max_tokens + if isinstance(request, + (EmbeddingChatRequest, EmbeddingCompletionRequest, + ScoreRequest, RerankRequest, ClassificationRequest)): + + if token_num > self.max_model_len: + operations: dict[type[AnyRequest], str] = { + ScoreRequest: "score", + ClassificationRequest: "classification" + } + operation = operations.get(type(request), + "embedding generation") + raise ValueError( + f"This model's maximum context length is " + f"{self.max_model_len} tokens. However, you requested " + f"{token_num} tokens in the input for {operation}. " + f"Please reduce the length of the input.") + return TextTokensPrompt(prompt=input_text, + prompt_token_ids=input_ids) + + # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens + # and does not require model context length validation + if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest, + DetokenizeRequest)): + return TextTokensPrompt(prompt=input_text, + prompt_token_ids=input_ids) + + # chat completion endpoint supports max_completion_tokens + if isinstance(request, ChatCompletionRequest): + # TODO(#9845): remove max_tokens when field dropped from OpenAI API + max_tokens = request.max_completion_tokens or request.max_tokens + else: + max_tokens = getattr(request, "max_tokens", None) + if max_tokens is None: + if token_num >= self.max_model_len: + raise ValueError( + f"This model's maximum context length is " + f"{self.max_model_len} tokens. However, you requested " + f"{token_num} tokens in the messages, " + f"Please reduce the length of the messages.") + elif token_num + max_tokens > self.max_model_len: + raise ValueError( + f"This model's maximum context length is " + f"{self.max_model_len} tokens. However, you requested " + f"{max_tokens + token_num} tokens " + f"({token_num} in the messages, " + f"{max_tokens} in the completion). " + f"Please reduce the length of the messages or completion.") + + return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) + + def _tokenize_prompt_input( + self, + request: AnyRequest, + tokenizer: AnyTokenizer, + prompt_input: Union[str, list[int]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, + add_special_tokens: bool = True, + ) -> TextTokensPrompt: + """ + A simpler implementation of + [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs] + that assumes single input. + """ + return next( + self._tokenize_prompt_inputs( + request, + tokenizer, + [prompt_input], + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + )) + + def _tokenize_prompt_inputs( + self, + request: AnyRequest, + tokenizer: AnyTokenizer, + prompt_inputs: Iterable[Union[str, list[int]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, + add_special_tokens: bool = True, + ) -> Iterator[TextTokensPrompt]: + """ + A simpler implementation of + [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs] + that assumes multiple inputs. + """ + for text in prompt_inputs: + if isinstance(text, str): + yield self._normalize_prompt_text_to_input( + request, + tokenizer, + prompt=text, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + ) + else: + yield self._normalize_prompt_tokens_to_input( + request, + tokenizer, + prompt_ids=text, + truncate_prompt_tokens=truncate_prompt_tokens, + ) + + def _tokenize_prompt_input_or_inputs( + self, + request: AnyRequest, + tokenizer: AnyTokenizer, + input_or_inputs: Optional[Union[str, list[str], list[int], + list[list[int]]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, + add_special_tokens: bool = True, + ) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]: + """ + Tokenize/detokenize depending on the input format. + + According to `OpenAI API `_ + , each input can be a string or array of tokens. Note that each request + can pass one or more inputs. + """ + inputs_embeds = list[EmbedsPrompt]() + inputs_text = list[TextTokensPrompt]() + + if (isinstance(request, CompletionRequest) + and request.prompt_embeds is not None): + inputs_embeds.extend( + self._load_prompt_embeds(request.prompt_embeds, + truncate_prompt_tokens)) + + # Empty prompts are okay as long as there are prompt embeddings + if input_or_inputs is None or (inputs_embeds + and input_or_inputs == ""): + return [], inputs_embeds + + # Although our type checking is based on mypy, + # VSCode Pyright extension should still work properly + # "is False" is required for Pyright to perform type narrowing + # See: https://github.com/microsoft/pyright/issues/7672 + inputs_text.extend([ + self._normalize_prompt_text_to_input( + request, + tokenizer, + prompt=prompt_input["content"], + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens) + if prompt_input["is_tokens"] is False else + self._normalize_prompt_tokens_to_input( + request, + tokenizer, + prompt_ids=prompt_input["content"], + truncate_prompt_tokens=truncate_prompt_tokens) + for prompt_input in parse_and_batch_prompt(input_or_inputs) + ]) + + return inputs_text, inputs_embeds + + @overload + async def _preprocess_completion( + self, + request: Union[DetokenizeRequest, EmbeddingCompletionRequest, + RerankRequest, ClassificationRequest, ScoreRequest, + TokenizeCompletionRequest], + tokenizer: AnyTokenizer, + input_or_inputs: Union[str, list[str], list[int], list[list[int]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ..., + add_special_tokens: bool = ..., + ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]: + ... + + @overload + async def _preprocess_completion( + self, + request: CompletionRequest, + tokenizer: AnyTokenizer, + input_or_inputs: Optional[Union[str, list[str], list[int], + list[list[int]]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ..., + add_special_tokens: bool = ..., + ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[ + EngineTokensPrompt, EngineEmbedsPrompt]]]: + ... + + async def _preprocess_completion( + self, + request: CompletionLikeRequest, + tokenizer: AnyTokenizer, + input_or_inputs: Optional[Union[str, list[str], list[int], + list[list[int]]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, + add_special_tokens: bool = True, + ) -> tuple[Union[list[TextTokensPrompt], list[Union[ + TextTokensPrompt, EmbedsPrompt]]], Union[ + list[EngineTokensPrompt], list[Union[EngineTokensPrompt, + EngineEmbedsPrompt]]]]: + if not isinstance(request, + CompletionRequest) and input_or_inputs is None: + raise ValueError( + "Prompt embeds with non-completion requests is not" + " currently supported.") + + (request_prompts_text, request_prompts_embeds + ) = await self._tokenize_prompt_input_or_inputs_async( + request, + tokenizer, + input_or_inputs, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + ) + + engine_prompts_text = [ + EngineTokensPrompt( + prompt_token_ids=request_prompt_text["prompt_token_ids"]) + for request_prompt_text in request_prompts_text + ] + + # This check is equivalent to simply checking if + # `request_prompts_embeds` is empty, but it's difficult to propagate + # overloads to the private helper functions to enable this check. + # This overload is needed because only TextPrompts are allowed for + # non-completion requests and if we don't add the overload here, + # everywhere this function is used outside of serving_completion will + # need logic asserting that only text prompts are in the request. + if not isinstance(request, + CompletionRequest) and input_or_inputs is not None: + return request_prompts_text, engine_prompts_text + + engine_prompts_embeds = [ + EngineEmbedsPrompt( + prompt_embeds=request_prompt_embeds["prompt_embeds"]) + for request_prompt_embeds in request_prompts_embeds + ] + + request_prompts = request_prompts_embeds + request_prompts_text + engine_prompts = engine_prompts_embeds + engine_prompts_text + return request_prompts, engine_prompts + + async def _preprocess_chat( + self, + request: ChatLikeRequest, + tokenizer: AnyTokenizer, + messages: list[ChatCompletionMessageParam], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + add_generation_prompt: bool = True, + continue_final_message: bool = False, + tool_dicts: Optional[list[dict[str, Any]]] = None, + documents: Optional[list[dict[str, str]]] = None, + chat_template_kwargs: Optional[dict[str, Any]] = None, + tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + add_special_tokens: bool = False, + ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt], + list[EngineTokensPrompt]]: + model_config = self.model_config + + resolved_content_format = resolve_chat_template_content_format( + chat_template, + tool_dicts, + chat_template_content_format, + tokenizer, + model_config=model_config, + ) + conversation, mm_data_future = parse_chat_messages_futures( + messages, + model_config, + tokenizer, + content_format=resolved_content_format, + ) + + _chat_template_kwargs: dict[str, Any] = dict( + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + tools=tool_dicts, + documents=documents, + ) + _chat_template_kwargs.update(chat_template_kwargs or {}) + + request_prompt: Union[str, list[int]] + if isinstance(tokenizer, MistralTokenizer): + request_prompt = apply_mistral_chat_template( + tokenizer, + messages=messages, + **_chat_template_kwargs, + ) + else: + request_prompt = apply_hf_chat_template( + tokenizer=tokenizer, + conversation=conversation, + model_config=model_config, + **_chat_template_kwargs, + ) + + mm_data = await mm_data_future + + # tool parsing is done only if a tool_parser has been set and if + # tool_choice is not "none" (if tool_choice is "none" but a tool_parser + # is set, we want to prevent parsing a tool_call hallucinated by the LLM + should_parse_tools = tool_parser is not None and (hasattr( + request, "tool_choice") and request.tool_choice != "none") + + if should_parse_tools: + if not isinstance(request, ChatCompletionRequest): + msg = "Tool usage is only supported for Chat Completions API" + raise NotImplementedError(msg) + + request = tool_parser(tokenizer).adjust_request( # type: ignore + request=request) + + if isinstance(request_prompt, str): + prompt_inputs = await self._tokenize_prompt_input_async( + request, + tokenizer, + request_prompt, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + ) + else: + # For MistralTokenizer + assert is_list_of(request_prompt, int), ( + "Prompt has to be either a string or a list of token ids") + prompt_inputs = TextTokensPrompt( + prompt=tokenizer.decode(request_prompt), + prompt_token_ids=request_prompt) + + engine_prompt = EngineTokensPrompt( + prompt_token_ids=prompt_inputs["prompt_token_ids"]) + if mm_data is not None: + engine_prompt["multi_modal_data"] = mm_data + if request.mm_processor_kwargs is not None: + engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs + + if hasattr(request, "cache_salt") and request.cache_salt is not None: + engine_prompt["cache_salt"] = request.cache_salt + + return conversation, [request_prompt], [engine_prompt] + + def _load_prompt_embeds( + self, + prompt_embeds: Optional[Union[bytes, list[bytes]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + ) -> list[EmbedsPrompt]: + + def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt: + tensor = torch.load(io.BytesIO(base64.b64decode(embed)), + weights_only=True) + assert isinstance( + tensor, + (torch.FloatTensor, torch.BFloat16Tensor, torch.HalfTensor)) + if tensor.dim() > 2: + tensor = tensor.squeeze(0) + assert tensor.dim() == 2 + if truncate_prompt_tokens is not None: + tensor = tensor[-truncate_prompt_tokens:] + return {"prompt_embeds": tensor} + + if prompt_embeds: + if isinstance(prompt_embeds, list): + return [ + _load_and_validate_embed(embed) for embed in prompt_embeds + ] + else: + return [_load_and_validate_embed(prompt_embeds)] + else: + return [] + + def _log_inputs( + self, + request_id: str, + inputs: RequestPrompt, + params: Optional[Union[SamplingParams, PoolingParams, + BeamSearchParams]], + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> None: + if self.request_logger is None: + return + prompt, prompt_token_ids, prompt_embeds = None, None, None + if isinstance(inputs, str): + prompt = inputs + elif isinstance(inputs, list): + prompt_token_ids = inputs + elif 'prompt_embeds' in inputs: + prompt_embeds = inputs.get("prompt_embeds") + else: + prompt = inputs["prompt"] + prompt_token_ids = inputs["prompt_token_ids"] + + self.request_logger.log_inputs( + request_id, + prompt, + prompt_token_ids, + prompt_embeds, + params=params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + async def _get_trace_headers( + self, + headers: Headers, + ) -> Optional[Mapping[str, str]]: + is_tracing_enabled = await self.engine_client.is_tracing_enabled() + + if is_tracing_enabled: + return extract_trace_headers(headers) + + if contains_trace_headers(headers): + log_tracing_disabled_warning() + + return None + + @staticmethod + def _base_request_id(raw_request: Optional[Request], + default: Optional[str] = None) -> Optional[str]: + """Pulls the request id to use from a header, if provided""" + default = default or random_uuid() + if raw_request is None: + return default + + return raw_request.headers.get("X-Request-Id", default) + + @staticmethod + def _get_decoded_token(logprob: Logprob, + token_id: int, + tokenizer: AnyTokenizer, + return_as_token_id: bool = False) -> str: + if return_as_token_id: + return f"token_id:{token_id}" + + if logprob.decoded_token is not None: + return logprob.decoded_token + return tokenizer.decode(token_id) + + def _is_model_supported(self, model_name: Optional[str]) -> bool: + if not model_name: + return True + return self.models.is_base_model(model_name) + + def _get_model_name(self, + model_name: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> str: + if lora_request: + return lora_request.lora_name + if not model_name: + return self.models.base_model_paths[0].name + return model_name + + +def clamp_prompt_logprobs( + prompt_logprobs: Union[PromptLogprobs, + None]) -> Union[PromptLogprobs, None]: + if prompt_logprobs is None: + return prompt_logprobs + + for logprob_dict in prompt_logprobs: + if logprob_dict is None: + continue + for logprob_values in logprob_dict.values(): + if logprob_values.logprob == float('-inf'): + logprob_values.logprob = -9999.0 + return prompt_logprobs diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py new file mode 100644 index 0000000..764b0e7 --- /dev/null +++ b/vllm/entrypoints/openai/serving_models.py @@ -0,0 +1,315 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +import pathlib +from asyncio import Lock +from collections import defaultdict +from dataclasses import dataclass +from http import HTTPStatus +from typing import Optional, Union + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.openai.protocol import (ErrorResponse, + LoadLoRAAdapterRequest, + ModelCard, ModelList, + ModelPermission, + UnloadLoRAAdapterRequest) +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.utils import AtomicCounter + +logger = init_logger(__name__) + + +@dataclass +class BaseModelPath: + name: str + model_path: str + + +@dataclass +class PromptAdapterPath: + name: str + local_path: str + + +@dataclass +class LoRAModulePath: + name: str + path: str + base_model_name: Optional[str] = None + + +class OpenAIServingModels: + """Shared instance to hold data about the loaded base model(s) and adapters. + + Handles the routes: + - /v1/models + - /v1/load_lora_adapter + - /v1/unload_lora_adapter + """ + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + base_model_paths: list[BaseModelPath], + *, + lora_modules: Optional[list[LoRAModulePath]] = None, + prompt_adapters: Optional[list[PromptAdapterPath]] = None, + ): + super().__init__() + + self.base_model_paths = base_model_paths + self.max_model_len = model_config.max_model_len + self.engine_client = engine_client + self.model_config = model_config + + self.static_lora_modules = lora_modules + self.lora_requests: list[LoRARequest] = [] + self.lora_id_counter = AtomicCounter(0) + + self.lora_resolvers: list[LoRAResolver] = [] + for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers( + ): + self.lora_resolvers.append( + LoRAResolverRegistry.get_resolver(lora_resolver_name)) + self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock) + + self.prompt_adapter_requests = [] + if prompt_adapters is not None: + for i, prompt_adapter in enumerate(prompt_adapters, start=1): + with pathlib.Path(prompt_adapter.local_path, + "adapter_config.json").open() as f: + adapter_config = json.load(f) + num_virtual_tokens = adapter_config["num_virtual_tokens"] + self.prompt_adapter_requests.append( + PromptAdapterRequest( + prompt_adapter_name=prompt_adapter.name, + prompt_adapter_id=i, + prompt_adapter_local_path=prompt_adapter.local_path, + prompt_adapter_num_virtual_tokens=num_virtual_tokens)) + + async def init_static_loras(self): + """Loads all static LoRA modules. + Raises if any fail to load""" + if self.static_lora_modules is None: + return + for lora in self.static_lora_modules: + load_request = LoadLoRAAdapterRequest(lora_path=lora.path, + lora_name=lora.name) + load_result = await self.load_lora_adapter( + request=load_request, base_model_name=lora.base_model_name) + if isinstance(load_result, ErrorResponse): + raise ValueError(load_result.message) + + def is_base_model(self, model_name) -> bool: + return any(model.name == model_name for model in self.base_model_paths) + + def model_name(self, lora_request: Optional[LoRARequest] = None) -> str: + """Returns the appropriate model name depending on the availability + and support of the LoRA or base model. + Parameters: + - lora: LoRARequest that contain a base_model_name. + Returns: + - str: The name of the base model or the first available model path. + """ + if lora_request is not None: + return lora_request.lora_name + return self.base_model_paths[0].name + + async def show_available_models(self) -> ModelList: + """Show available models. This includes the base model and all + adapters""" + model_cards = [ + ModelCard(id=base_model.name, + max_model_len=self.max_model_len, + root=base_model.model_path, + permission=[ModelPermission()]) + for base_model in self.base_model_paths + ] + lora_cards = [ + ModelCard(id=lora.lora_name, + root=lora.local_path, + parent=lora.base_model_name if lora.base_model_name else + self.base_model_paths[0].name, + permission=[ModelPermission()]) + for lora in self.lora_requests + ] + prompt_adapter_cards = [ + ModelCard(id=prompt_adapter.prompt_adapter_name, + root=self.base_model_paths[0].name, + permission=[ModelPermission()]) + for prompt_adapter in self.prompt_adapter_requests + ] + model_cards.extend(lora_cards) + model_cards.extend(prompt_adapter_cards) + return ModelList(data=model_cards) + + async def load_lora_adapter( + self, + request: LoadLoRAAdapterRequest, + base_model_name: Optional[str] = None + ) -> Union[ErrorResponse, str]: + error_check_ret = await self._check_load_lora_adapter_request(request) + if error_check_ret is not None: + return error_check_ret + + lora_name, lora_path = request.lora_name, request.lora_path + unique_id = self.lora_id_counter.inc(1) + lora_request = LoRARequest(lora_name=lora_name, + lora_int_id=unique_id, + lora_path=lora_path) + if base_model_name is not None and self.is_base_model(base_model_name): + lora_request.base_model_name = base_model_name + + # Validate that the adapter can be loaded into the engine + # This will also pre-load it for incoming requests + try: + await self.engine_client.add_lora(lora_request) + except BaseException as e: + error_type = "BadRequestError" + status_code = HTTPStatus.BAD_REQUEST + if "No adapter found" in str(e): + error_type = "NotFoundError" + status_code = HTTPStatus.NOT_FOUND + + return create_error_response(message=str(e), + err_type=error_type, + status_code=status_code) + + self.lora_requests.append(lora_request) + logger.info("Loaded new LoRA adapter: name '%s', path '%s'", lora_name, + lora_path) + return f"Success: LoRA adapter '{lora_name}' added successfully." + + async def unload_lora_adapter( + self, + request: UnloadLoRAAdapterRequest) -> Union[ErrorResponse, str]: + error_check_ret = await self._check_unload_lora_adapter_request(request + ) + if error_check_ret is not None: + return error_check_ret + + lora_name = request.lora_name + self.lora_requests = [ + lora_request for lora_request in self.lora_requests + if lora_request.lora_name != lora_name + ] + logger.info("Removed LoRA adapter: name '%s'", lora_name) + return f"Success: LoRA adapter '{lora_name}' removed successfully." + + async def _check_load_lora_adapter_request( + self, request: LoadLoRAAdapterRequest) -> Optional[ErrorResponse]: + # Check if both 'lora_name' and 'lora_path' are provided + if not request.lora_name or not request.lora_path: + return create_error_response( + message="Both 'lora_name' and 'lora_path' must be provided.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + # Check if the lora adapter with the given name already exists + if any(lora_request.lora_name == request.lora_name + for lora_request in self.lora_requests): + return create_error_response( + message= + f"The lora adapter '{request.lora_name}' has already been " + "loaded.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + return None + + async def _check_unload_lora_adapter_request( + self, + request: UnloadLoRAAdapterRequest) -> Optional[ErrorResponse]: + # Check if either 'lora_name' or 'lora_int_id' is provided + if not request.lora_name and not request.lora_int_id: + return create_error_response( + message= + "either 'lora_name' and 'lora_int_id' needs to be provided.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + # Check if the lora adapter with the given name exists + if not any(lora_request.lora_name == request.lora_name + for lora_request in self.lora_requests): + return create_error_response( + message= + f"The lora adapter '{request.lora_name}' cannot be found.", + err_type="NotFoundError", + status_code=HTTPStatus.NOT_FOUND) + + return None + + async def resolve_lora( + self, lora_name: str) -> Union[LoRARequest, ErrorResponse]: + """Attempt to resolve a LoRA adapter using available resolvers. + + Args: + lora_name: Name/identifier of the LoRA adapter + + Returns: + LoRARequest if found and loaded successfully. + ErrorResponse (404) if no resolver finds the adapter. + ErrorResponse (400) if adapter(s) are found but none load. + """ + async with self.lora_resolver_lock[lora_name]: + # First check if this LoRA is already loaded + for existing in self.lora_requests: + if existing.lora_name == lora_name: + return existing + + base_model_name = self.model_config.model + unique_id = self.lora_id_counter.inc(1) + found_adapter = False + + # Try to resolve using available resolvers + for resolver in self.lora_resolvers: + lora_request = await resolver.resolve_lora( + base_model_name, lora_name) + + if lora_request is not None: + found_adapter = True + lora_request.lora_int_id = unique_id + + try: + await self.engine_client.add_lora(lora_request) + self.lora_requests.append(lora_request) + logger.info( + "Resolved and loaded LoRA adapter '%s' using %s", + lora_name, resolver.__class__.__name__) + return lora_request + except BaseException as e: + logger.warning( + "Failed to load LoRA '%s' resolved by %s: %s. " + "Trying next resolver.", lora_name, + resolver.__class__.__name__, e) + continue + + if found_adapter: + # An adapter was found, but all attempts to load it failed. + return create_error_response( + message=(f"LoRA adapter '{lora_name}' was found " + "but could not be loaded."), + err_type="BadRequestError", + status_code=HTTPStatus.BAD_REQUEST) + else: + # No adapter was found + return create_error_response( + message=f"LoRA adapter {lora_name} does not exist", + err_type="NotFoundError", + status_code=HTTPStatus.NOT_FOUND) + + +def create_error_response( + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: + return ErrorResponse(message=message, + type=err_type, + code=status_code.value) diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py new file mode 100644 index 0000000..c2ed50d --- /dev/null +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -0,0 +1,234 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import base64 +import time +from collections.abc import AsyncGenerator +from typing import Final, Literal, Optional, Union, cast + +import jinja2 +import numpy as np +import torch +from fastapi import Request +from typing_extensions import assert_never + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (ErrorResponse, + PoolingChatRequest, + PoolingRequest, PoolingResponse, + PoolingResponseData, UsageInfo) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.utils import _validate_truncation_size +from vllm.logger import init_logger +from vllm.outputs import PoolingOutput, PoolingRequestOutput +from vllm.utils import merge_async_iterators + +logger = init_logger(__name__) + + +def _get_data( + output: PoolingOutput, + encoding_format: Literal["float", "base64"], +) -> Union[list[float], str]: + if encoding_format == "float": + return output.data.tolist() + elif encoding_format == "base64": + # Force to use float32 for base64 encoding + # to match the OpenAI python client behavior + pt_float32 = output.data.to(dtype=torch.float32) + pooling_bytes = np.array(pt_float32, dtype="float32").tobytes() + return base64.b64encode(pooling_bytes).decode("utf-8") + + assert_never(encoding_format) + + +class OpenAIServingPooling(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + ) -> None: + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger) + + self.chat_template = chat_template + self.chat_template_content_format: Final = chat_template_content_format + + async def create_pooling( + self, + request: PoolingRequest, + raw_request: Optional[Request] = None, + ) -> Union[PoolingResponse, ErrorResponse]: + """ + See https://platform.openai.com/docs/api-reference/embeddings/create + for the API specification. This API mimics the OpenAI Embedding API. + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + encoding_format = request.encoding_format + if request.dimensions is not None: + return self.create_error_response( + "dimensions is currently not supported") + + model_name = self._get_model_name(request.model) + request_id = f"pool-{self._base_request_id(raw_request)}" + created_time = int(time.time()) + + truncate_prompt_tokens = request.truncate_prompt_tokens + + try: + truncate_prompt_tokens = _validate_truncation_size( + self.max_model_len, truncate_prompt_tokens) + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + if prompt_adapter_request is not None: + raise NotImplementedError("Prompt adapter is not supported " + "for pooling models") + + if isinstance(request, PoolingChatRequest): + ( + _, + request_prompts, + engine_prompts, + ) = await self._preprocess_chat( + request, + tokenizer, + request.messages, + chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self. + chat_template_content_format, + # In pooling requests, we are not generating tokens, + # so there is no need to append extra tokens to the input + add_generation_prompt=False, + continue_final_message=False, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + else: + (request_prompts, + engine_prompts) = await self._preprocess_completion( + request, + tokenizer, + request.input, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + except (ValueError, TypeError, jinja2.TemplateError) as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + # Schedule the request and get the result generator. + generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] + try: + pooling_params = request.to_pooling_params() + + for i, engine_prompt in enumerate(engine_prompts): + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + request_prompts[i], + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) + + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + result_generator = merge_async_iterators(*generators) + + num_prompts = len(engine_prompts) + + # Non-streaming response + final_res_batch: list[Optional[PoolingRequestOutput]] + final_res_batch = [None] * num_prompts + try: + async for i, res in result_generator: + final_res_batch[i] = res + + assert all(final_res is not None for final_res in final_res_batch) + + final_res_batch_checked = cast(list[PoolingRequestOutput], + final_res_batch) + + response = self.request_output_to_pooling_response( + final_res_batch_checked, + request_id, + created_time, + model_name, + encoding_format, + ) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + return response + + def request_output_to_pooling_response( + self, + final_res_batch: list[PoolingRequestOutput], + request_id: str, + created_time: int, + model_name: str, + encoding_format: Literal["float", "base64"], + ) -> PoolingResponse: + items: list[PoolingResponseData] = [] + num_prompt_tokens = 0 + + for idx, final_res in enumerate(final_res_batch): + item = PoolingResponseData( + index=idx, + data=_get_data(final_res.outputs, encoding_format), + ) + prompt_token_ids = final_res.prompt_token_ids + + items.append(item) + num_prompt_tokens += len(prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, + ) + + return PoolingResponse( + id=request_id, + created=created_time, + model=model_name, + data=items, + usage=usage, + ) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py new file mode 100644 index 0000000..328d4ff --- /dev/null +++ b/vllm/entrypoints/openai/serving_score.py @@ -0,0 +1,431 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import time +from collections.abc import AsyncGenerator, Mapping +from typing import Any, Optional, Union + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument, + RerankRequest, RerankResponse, + RerankResult, RerankUsage, + ScoreRequest, ScoreResponse, + ScoreResponseData, UsageInfo) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.score_utils import (_cosine_similarity, + _validate_score_input_lens) +from vllm.entrypoints.utils import _validate_truncation_size +from vllm.inputs.data import TokensPrompt +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import make_async, merge_async_iterators + +logger = init_logger(__name__) + + +class ServingScores(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + ) -> None: + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger) + + async def _embedding_score( + self, + tokenizer: AnyTokenizer, + texts_1: list[str], + texts_2: list[str], + request: Union[RerankRequest, ScoreRequest], + request_id=str, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[Union[LoRARequest, None]] = None, + prompt_adapter_request: Optional[Union[PromptAdapterRequest, + None]] = None, + trace_headers: Optional[Mapping[str, str]] = None, + ) -> list[PoolingRequestOutput]: + + input_texts = texts_1 + texts_2 + + engine_prompts: list[TokensPrompt] = [] + tokenize_async = make_async(tokenizer.__call__, + executor=self._tokenizer_executor) + + tokenization_kwargs = tokenization_kwargs or {} + tokenized_prompts = await asyncio.gather( + *(tokenize_async(t, **tokenization_kwargs) for t in input_texts)) + + for tok_result, input_text in zip(tokenized_prompts, input_texts): + + text_token_prompt = \ + self._validate_input( + request, + tok_result["input_ids"], + input_text) + + engine_prompts.append( + TokensPrompt( + prompt_token_ids=text_token_prompt["prompt_token_ids"])) + + # Schedule the request and get the result generator. + generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] + pooling_params = request.to_pooling_params() + + for i, engine_prompt in enumerate(engine_prompts): + + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + input_texts[i], + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + generators.append( + self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + )) + + result_generator = merge_async_iterators(*generators) + + # Non-streaming response + final_res_batch: list[PoolingRequestOutput] = [] + + embeddings: list[Optional[PoolingRequestOutput]] =\ + [None] * len(engine_prompts) + + async for i, res in result_generator: + embeddings[i] = res + + emb_texts_1: list[PoolingRequestOutput] = [] + emb_texts_2: list[PoolingRequestOutput] = [] + + for i in range(0, len(texts_1)): + assert (emb := embeddings[i]) is not None + emb_texts_1.append(emb) + + for i in range(len(texts_1), len(embeddings)): + assert (emb := embeddings[i]) is not None + emb_texts_2.append(emb) + + if len(emb_texts_1) == 1: + emb_texts_1 = emb_texts_1 * len(emb_texts_2) + + final_res_batch = _cosine_similarity(tokenizer=tokenizer, + embed_1=emb_texts_1, + embed_2=emb_texts_2) + + return final_res_batch + + async def _cross_encoding_score( + self, + tokenizer: AnyTokenizer, + texts_1: list[str], + texts_2: list[str], + request: Union[RerankRequest, ScoreRequest], + request_id=str, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[Union[LoRARequest, None]] = None, + prompt_adapter_request: Optional[Union[PromptAdapterRequest, + None]] = None, + trace_headers: Optional[Mapping[str, str]] = None, + ) -> list[PoolingRequestOutput]: + + request_prompts: list[str] = [] + engine_prompts: list[TokensPrompt] = [] + + if len(texts_1) == 1: + texts_1 = texts_1 * len(texts_2) + + input_pairs = [(t1, t2) for t1, t2 in zip(texts_1, texts_2)] + + if isinstance(tokenizer, MistralTokenizer): + raise ValueError( + "MistralTokenizer not supported for cross-encoding") + + tokenize_async = make_async(tokenizer.__call__, + executor=self._tokenizer_executor) + + tokenization_kwargs = tokenization_kwargs or {} + tokenized_prompts = await asyncio.gather( + *(tokenize_async(text=t1, text_pair=t2, **tokenization_kwargs) + for t1, t2 in input_pairs)) + + for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs): + sep_token = tokenizer.sep_token if tokenizer.sep_token else '' + request_prompt = f"{t1}{sep_token}{t2}" + + input_ids = prompt_inputs["input_ids"] + text_token_prompt = \ + self._validate_input(request, input_ids, request_prompt) + engine_prompt = TokensPrompt( + prompt_token_ids=text_token_prompt["prompt_token_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + + request_prompts.append(request_prompt) + engine_prompts.append(engine_prompt) + + # Schedule the request and get the result generator. + generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] + + pooling_params = request.to_pooling_params(use_cross_encoder=True) + + for i, engine_prompt in enumerate(engine_prompts): + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + request_prompts[i], + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) + + generators.append(generator) + + result_generator = merge_async_iterators(*generators) + + # Non-streaming response + final_res_batch: list[ + Optional[PoolingRequestOutput]] = [None] * len(engine_prompts) + + async for i, res in result_generator: + final_res_batch[i] = res + + return [out for out in final_res_batch if out is not None] + + async def _run_scoring( + self, + texts_1: Union[str, list[str]], + texts_2: Union[str, list[str]], + request: Union[ScoreRequest, RerankRequest], + request_id: str, + raw_request: Optional[Request] = None, + truncate_prompt_tokens: Optional[int] = None, + ) -> list[PoolingRequestOutput]: + + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + if prompt_adapter_request is not None: + raise NotImplementedError("Prompt adapter is not supported " + "for scoring models") + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + tokenization_kwargs: dict[str, Any] = {} + _validate_truncation_size(self.max_model_len, truncate_prompt_tokens, + tokenization_kwargs) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + if isinstance(texts_1, str): + texts_1 = [texts_1] + if isinstance(texts_2, str): + texts_2 = [texts_2] + + _validate_score_input_lens(texts_1, texts_2) + + if self.model_config.is_cross_encoder: + return await self._cross_encoding_score( + tokenizer=tokenizer, + texts_1=texts_1, + texts_2=texts_2, + request=request, + request_id=request_id, + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers) + + else: + return await self._embedding_score( + tokenizer=tokenizer, + texts_1=texts_1, + texts_2=texts_2, + request=request, + request_id=request_id, + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers) + + async def create_score( + self, + request: ScoreRequest, + raw_request: Optional[Request] = None, + ) -> Union[ScoreResponse, ErrorResponse]: + """ + Score API similar to Sentence Transformers cross encoder + + See https://sbert.net/docs/package_reference/cross_encoder + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + request_id = f"score-{self._base_request_id(raw_request)}" + created_time = int(time.time()) + + try: + final_res_batch = await self._run_scoring( + request.text_1, + request.text_2, + request, + request_id, + raw_request, + request.truncate_prompt_tokens, + ) + + return self.request_output_to_score_response( + final_res_batch, + request_id, + created_time, + self._get_model_name(request.model), + ) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + async def do_rerank( + self, + request: RerankRequest, + raw_request: Optional[Request] = None + ) -> Union[RerankResponse, ErrorResponse]: + """ + Rerank API based on JinaAI's rerank API; implements the same + API interface. Designed for compatibility with off-the-shelf + tooling, since this is a common standard for reranking APIs + + See example client implementations at + https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py + numerous clients use this standard. + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + request_id = f"rerank-{self._base_request_id(raw_request)}" + documents = request.documents + top_n = request.top_n if request.top_n > 0 else len(documents) + + try: + final_res_batch = await self._run_scoring( + request.query, + documents, + request, + request_id, + raw_request, + request.truncate_prompt_tokens, + ) + return self.request_output_to_rerank_response( + final_res_batch, + request_id, + self._get_model_name(request.model), + documents, + top_n, + ) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + def request_output_to_score_response( + self, + final_res_batch: list[PoolingRequestOutput], + request_id: str, + created_time: int, + model_name: str, + ) -> ScoreResponse: + items: list[ScoreResponseData] = [] + num_prompt_tokens = 0 + + for idx, final_res in enumerate(final_res_batch): + classify_res = ScoringRequestOutput.from_base(final_res) + + item = ScoreResponseData( + index=idx, + score=classify_res.outputs.score, + ) + prompt_token_ids = final_res.prompt_token_ids + + items.append(item) + num_prompt_tokens += len(prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, + ) + + return ScoreResponse( + id=request_id, + created=created_time, + model=model_name, + data=items, + usage=usage, + ) + + def request_output_to_rerank_response( + self, final_res_batch: list[PoolingRequestOutput], request_id: str, + model_name: str, documents: list[str], + top_n: int) -> RerankResponse: + """ + Convert the output of do_rank to a RerankResponse + """ + results: list[RerankResult] = [] + num_prompt_tokens = 0 + for idx, final_res in enumerate(final_res_batch): + classify_res = ScoringRequestOutput.from_base(final_res) + + result = RerankResult( + index=idx, + document=RerankDocument(text=documents[idx]), + relevance_score=classify_res.outputs.score, + ) + results.append(result) + prompt_token_ids = final_res.prompt_token_ids + num_prompt_tokens += len(prompt_token_ids) + + # sort by relevance, then return the top n if set + results.sort(key=lambda x: x.relevance_score, reverse=True) + if top_n < len(documents): + results = results[:top_n] + + return RerankResponse( + id=request_id, + model=model_name, + results=results, + usage=RerankUsage(total_tokens=num_prompt_tokens)) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py new file mode 100644 index 0000000..3db0a71 --- /dev/null +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Final, Optional, Union + +import jinja2 +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption +from vllm.entrypoints.logger import RequestLogger +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.openai.protocol import (DetokenizeRequest, + DetokenizeResponse, + ErrorResponse, + TokenizeChatRequest, + TokenizeRequest, + TokenizeResponse) +# yapf: enable +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class OpenAIServingTokenization(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + ) -> None: + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger) + + self.chat_template = chat_template + self.chat_template_content_format: Final = chat_template_content_format + + async def create_tokenize( + self, + request: TokenizeRequest, + raw_request: Request, + ) -> Union[TokenizeResponse, ErrorResponse]: + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + request_id = f"tokn-{self._base_request_id(raw_request)}" + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + if isinstance(request, TokenizeChatRequest): + tool_dicts = (None if request.tools is None else + [tool.model_dump() for tool in request.tools]) + ( + _, + request_prompts, + engine_prompts, + ) = await self._preprocess_chat( + request, + tokenizer, + request.messages, + tool_dicts=tool_dicts, + chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self. + chat_template_content_format, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, + chat_template_kwargs=request.chat_template_kwargs, + add_special_tokens=request.add_special_tokens, + ) + else: + (request_prompts, + engine_prompts) = await self._preprocess_completion( + request, + tokenizer, + request.prompt, + add_special_tokens=request.add_special_tokens, + ) + except (ValueError, TypeError, jinja2.TemplateError) as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(f"{e} {e.__cause__}") + + input_ids: list[int] = [] + for i, engine_prompt in enumerate(engine_prompts): + self._log_inputs(request_id, + request_prompts[i], + params=None, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + # Silently ignore prompt adapter since it does not affect + # tokenization (Unlike in Embeddings API where an error is raised) + if isinstance(engine_prompt, + dict) and "prompt_token_ids" in engine_prompt: + input_ids.extend(engine_prompt["prompt_token_ids"]) + + token_strs = None + if request.return_token_strs: + token_strs = tokenizer.convert_ids_to_tokens(input_ids) + + return TokenizeResponse(tokens=input_ids, + token_strs=token_strs, + count=len(input_ids), + max_model_len=self.max_model_len) + + async def create_detokenize( + self, + request: DetokenizeRequest, + raw_request: Request, + ) -> Union[DetokenizeResponse, ErrorResponse]: + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + request_id = f"tokn-{self._base_request_id(raw_request)}" + + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + self._log_inputs(request_id, + request.tokens, + params=None, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + # Silently ignore prompt adapter since it does not affect tokenization + # (Unlike in Embeddings API where an error is raised) + + prompt_input = await self._tokenize_prompt_input_async( + request, + tokenizer, + request.tokens, + ) + input_text = prompt_input["prompt"] + + return DetokenizeResponse(prompt=input_text) diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py new file mode 100644 index 0000000..0d6989f --- /dev/null +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import AsyncGenerator +from typing import Optional, Union + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import ( + ErrorResponse, RequestResponseMetadata, TranscriptionRequest, + TranscriptionResponse, TranscriptionResponseStreamChoice, + TranscriptionStreamResponse, TranslationRequest, TranslationResponse, + TranslationResponseStreamChoice, TranslationStreamResponse) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.openai.speech_to_text import OpenAISpeechToText +from vllm.logger import init_logger +from vllm.outputs import RequestOutput + +logger = init_logger(__name__) + + +class OpenAIServingTranscription(OpenAISpeechToText): + """Handles transcription requests.""" + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + ): + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + task_type="transcribe") + + async def create_transcription( + self, audio_data: bytes, request: TranscriptionRequest, + raw_request: Request + ) -> Union[TranscriptionResponse, AsyncGenerator[str, None], + ErrorResponse]: + """Transcription API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/audio/createTranscription + for the API specification. This API mimics the OpenAI transcription API. + """ + return await self._create_speech_to_text( + audio_data=audio_data, + request=request, + raw_request=raw_request, + response_class=TranscriptionResponse, + stream_generator_method=self.transcription_stream_generator, + ) + + async def transcription_stream_generator( + self, request: TranscriptionRequest, + result_generator: list[AsyncGenerator[RequestOutput, None]], + request_id: str, request_metadata: RequestResponseMetadata, + audio_duration_s: float) -> AsyncGenerator[str, None]: + generator = self._speech_to_text_stream_generator( + request=request, + list_result_generator=result_generator, + request_id=request_id, + request_metadata=request_metadata, + audio_duration_s=audio_duration_s, + chunk_object_type="transcription.chunk", + response_stream_choice_class=TranscriptionResponseStreamChoice, + stream_response_class=TranscriptionStreamResponse, + ) + async for chunk in generator: + yield chunk + + +class OpenAIServingTranslation(OpenAISpeechToText): + """Handles translation requests.""" + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + ): + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + task_type="translate") + + async def create_translation( + self, audio_data: bytes, request: TranslationRequest, + raw_request: Request + ) -> Union[TranslationResponse, AsyncGenerator[str, None], ErrorResponse]: + """Translation API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/audio/createTranslation + for the API specification. This API mimics the OpenAI translation API. + """ + return await self._create_speech_to_text( + audio_data=audio_data, + request=request, + raw_request=raw_request, + response_class=TranslationResponse, + stream_generator_method=self.translation_stream_generator, + ) + + async def translation_stream_generator( + self, request: TranslationRequest, + result_generator: list[AsyncGenerator[RequestOutput, None]], + request_id: str, request_metadata: RequestResponseMetadata, + audio_duration_s: float) -> AsyncGenerator[str, None]: + generator = self._speech_to_text_stream_generator( + request=request, + list_result_generator=result_generator, + request_id=request_id, + request_metadata=request_metadata, + audio_duration_s=audio_duration_s, + chunk_object_type="translation.chunk", + response_stream_choice_class=TranslationResponseStreamChoice, + stream_response_class=TranslationStreamResponse, + ) + async for chunk in generator: + yield chunk diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py new file mode 100644 index 0000000..0ab029e --- /dev/null +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -0,0 +1,395 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import io +import math +import time +from collections.abc import AsyncGenerator +from functools import cached_property +from math import ceil +from typing import Callable, Literal, Optional, TypeVar, Union, cast + +import numpy as np +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import ( + DeltaMessage, ErrorResponse, RequestResponseMetadata, + TranscriptionResponse, TranscriptionResponseStreamChoice, + TranscriptionStreamResponse, TranslationResponse, + TranslationResponseStreamChoice, TranslationStreamResponse, UsageInfo) +from vllm.entrypoints.openai.serving_engine import (OpenAIServing, + SpeechToTextRequest) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.inputs.data import PromptType +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model_cls +from vllm.model_executor.models import SupportsTranscription +from vllm.outputs import RequestOutput +from vllm.transformers_utils.processor import cached_get_processor +from vllm.utils import PlaceholderModule + +try: + import librosa +except ImportError: + librosa = PlaceholderModule("librosa") # type: ignore[assignment] + +SpeechToTextResponse = Union[TranscriptionResponse, TranslationResponse] +T = TypeVar("T", bound=SpeechToTextResponse) + +logger = init_logger(__name__) + +# As per https://platform.openai.com/docs/guides/speech-to-text#overview. +# TODO configurable +MAX_AUDIO_CLIP_FILESIZE_MB = 25 +MAX_AUDIO_CLIP_SECONDS = 30 +OVERLAP_CHUNK_SECOND = 1 +MIN_ENERGY_WINDOW_SIZE = 1600 # 1600 ~ 100ms for 16000 Hz audio + + +class OpenAISpeechToText(OpenAIServing): + """Base class for speech-to-text operations like transcription and + translation.""" + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + task_type: Literal["transcribe", "translate"] = "transcribe", + ): + super().__init__(engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids) + + self.default_sampling_params = ( + self.model_config.get_diff_sampling_param()) + processor = cached_get_processor(model_config.model) + self.max_audio_clip_s = processor.feature_extractor.chunk_length \ + if hasattr(processor.feature_extractor, 'chunk_length') \ + else MAX_AUDIO_CLIP_SECONDS + self.model_sr = processor.feature_extractor.sampling_rate + self.hop_length = processor.feature_extractor.hop_length + self.task_type = task_type + + if self.default_sampling_params: + logger.info( + "Overwriting default completion sampling param with: %s", + self.default_sampling_params) + + @cached_property + def model_cls(self): + return get_model_cls(self.model_config) + + async def _preprocess_speech_to_text( + self, + request: SpeechToTextRequest, + audio_data: bytes, + ) -> tuple[list[PromptType], float]: + model_cls = cast(SupportsTranscription, self.model_cls) + + # Validate request + # TODO language should be optional and can be guessed. + # For now we default to en. See + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520 + lang = request.language or "en" + model_cls.validate_language(lang) + + if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB: + raise ValueError("Maximum file size exceeded.") + + with io.BytesIO(audio_data) as bytes_: + # NOTE resample to model SR here for efficiency. This is also a + # pre-requisite for chunking, as it assumes Whisper SR. + y, sr = librosa.load(bytes_, sr=self.model_sr) + + duration = librosa.get_duration(y=y, sr=sr) + chunks = [y + ] if duration < self.max_audio_clip_s else self._split_audio( + y, int(sr)) + prompts = [] + for chunk in chunks: + prompt = { + "encoder_prompt": { + "prompt": "", + "multi_modal_data": { + "audio": (chunk, sr), + }, + }, + "decoder_prompt": + model_cls.get_decoder_prompt(lang, self.task_type, + request.prompt) + } + prompts.append(cast(PromptType, prompt)) + return prompts, duration + + async def _create_speech_to_text( + self, + audio_data: bytes, + request: SpeechToTextRequest, + raw_request: Request, + response_class: type[T], + stream_generator_method: Callable[..., AsyncGenerator[str, None]], + ) -> Union[T, AsyncGenerator[str, None], ErrorResponse]: + """Base method for speech-to-text operations like transcription and + translation.""" + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + + if request.response_format not in ['text', 'json']: + return self.create_error_response( + "Currently only support response_format `text` or `json`") + + request_id = f"{self.task_type}-{self._base_request_id(raw_request)}" + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + if lora_request: + return self.create_error_response( + "Currently do not support LoRA for " + f"{self.task_type.title()}.") + if prompt_adapter_request: + return self.create_error_response( + f"Currently do not support PromptAdapter for " + f"{self.task_type.title()}.") + + prompts, duration_s = await self._preprocess_speech_to_text( + request=request, + audio_data=audio_data, + ) + + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + list_result_generator: Optional[list[AsyncGenerator[RequestOutput, + None]]] = None + try: + # Unlike most decoder-only models, whisper generation length is not + # constrained by the size of the input audio, which is mapped to a + # fixed-size log-mel-spectogram. + default_max_tokens = self.model_config.max_model_len + sampling_params = request.to_sampling_params( + default_max_tokens, self.default_sampling_params) + + self._log_inputs( + request_id, + prompts[0]['decoder_prompt'], # type: ignore + params=sampling_params, + lora_request=None, + prompt_adapter_request=None) + + list_result_generator = [ + self.engine_client.generate( + prompt, + sampling_params, + request_id, + ) for prompt in prompts + ] + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + if request.stream: + return stream_generator_method(request, list_result_generator, + request_id, request_metadata, + duration_s) + # Non-streaming response. + try: + assert list_result_generator is not None + text = "" + for result_generator in list_result_generator: + async for op in result_generator: + text += op.outputs[0].text + return cast(T, response_class(text=text)) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + async def _speech_to_text_stream_generator( + self, + request: SpeechToTextRequest, + list_result_generator: list[AsyncGenerator[RequestOutput, None]], + request_id: str, + request_metadata: RequestResponseMetadata, + audio_duration_s: float, + chunk_object_type: Literal["translation.chunk", "transcription.chunk"], + response_stream_choice_class: Union[ + type[TranscriptionResponseStreamChoice], + type[TranslationResponseStreamChoice]], + stream_response_class: Union[type[TranscriptionStreamResponse], + type[TranslationStreamResponse]], + ) -> AsyncGenerator[str, None]: + created_time = int(time.time()) + model_name = request.model + + completion_tokens = 0 + num_prompt_tokens = 0 + + include_usage = request.stream_include_usage \ + if request.stream_include_usage else False + include_continuous_usage = request.stream_continuous_usage_stats\ + if include_usage and request.stream_continuous_usage_stats\ + else False + + try: + for result_generator in list_result_generator: + async for res in result_generator: + # On first result. + if res.prompt_token_ids is not None: + # Do not account the 4-tokens `<|startoftranscript|>..` + # Could be negative when language token + # is not specified. + num_prompt_tokens = max( + len(res.prompt_token_ids) - 4, 0) + # NOTE(NickLucche) user can't pass encoder + # prompts directly at least not to Whisper. + # One indicator of the encoder amount of processing + # is the log-mel spectogram length. + num_prompt_tokens += ceil( + audio_duration_s * self.model_sr / self.hop_length) + + # We need to do it here, because if there are exceptions in + # the result_generator, it needs to be sent as the FIRST + # response (by the try...catch). + + # Just one output (n=1) supported. + assert len(res.outputs) == 1 + output = res.outputs[0] + + delta_message = DeltaMessage(content=output.text) + completion_tokens += len(output.token_ids) + + if output.finish_reason is None: + # Still generating, send delta update. + choice_data = response_stream_choice_class( + delta=delta_message) + else: + # Model is finished generating. + choice_data = response_stream_choice_class( + delta=delta_message, + finish_reason=output.finish_reason, + stop_reason=output.stop_reason) + + chunk = stream_response_class(id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + + # handle usage stats if requested & if continuous + if include_continuous_usage: + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, + ) + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # Once the final token is handled, if stream_options.include_usage + # is sent, send the usage. + if include_usage: + final_usage = UsageInfo(prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + + completion_tokens) + + final_usage_chunk = stream_response_class( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[], + model=model_name, + usage=final_usage) + final_usage_data = (final_usage_chunk.model_dump_json( + exclude_unset=True, exclude_none=True)) + yield f"data: {final_usage_data}\n\n" + + # report to FastAPI middleware aggregate usage across all choices + request_metadata.final_usage_info = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens) + + except Exception as e: + # TODO: Use a vllm-specific Validation Error + logger.exception("Error in %s stream generator.", self.task_type) + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + # Send the final done message after all response.n are finished + yield "data: [DONE]\n\n" + + def _split_audio(self, audio_data: np.ndarray, + sample_rate: int) -> list[np.ndarray]: + chunk_size = sample_rate * self.max_audio_clip_s + overlap_size = sample_rate * OVERLAP_CHUNK_SECOND + chunks = [] + i = 0 + while i < audio_data.shape[-1]: + if i + chunk_size >= audio_data.shape[-1]: + # handle last chunk + chunks.append(audio_data[..., i:]) + break + + # Find the best split point in the overlap region + search_start = i + chunk_size - overlap_size + search_end = min(i + chunk_size, audio_data.shape[-1]) + split_point = self._find_split_point(audio_data, search_start, + search_end) + + # Extract chunk up to the split point + chunks.append(audio_data[..., i:split_point]) + i = split_point + return chunks + + def _find_split_point(self, wav: np.ndarray, start_idx: int, + end_idx: int) -> int: + """Find the best point to split audio by + looking for silence or low amplitude. + Args: + wav: Audio tensor [1, T] + start_idx: Start index of search region + end_idx: End index of search region + Returns: + Index of best splitting point + """ + segment = wav[start_idx:end_idx] + + # Calculate RMS energy in small windows + min_energy = math.inf + quietest_idx = 0 + for i in range(0, + len(segment) - MIN_ENERGY_WINDOW_SIZE, + MIN_ENERGY_WINDOW_SIZE): + window = segment[i:i + MIN_ENERGY_WINDOW_SIZE] + energy = (window**2).mean()**0.5 + if energy < min_energy: + quietest_idx = i + start_idx + min_energy = energy + return quietest_idx diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py new file mode 100644 index 0000000..a0b75f8 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .abstract_tool_parser import ToolParser, ToolParserManager +from .deepseekv3_tool_parser import DeepSeekV3ToolParser +from .glm4_moe_tool_parser import Glm4MoeModelToolParser +from .granite_20b_fc_tool_parser import Granite20bFCToolParser +from .granite_tool_parser import GraniteToolParser +from .hermes_tool_parser import Hermes2ProToolParser +from .internlm2_tool_parser import Internlm2ToolParser +from .jamba_tool_parser import JambaToolParser +from .llama4_pythonic_tool_parser import Llama4PythonicToolParser +from .llama_tool_parser import Llama3JsonToolParser +from .minimax_tool_parser import MinimaxToolParser +from .mistral_tool_parser import MistralToolParser +from .phi4mini_tool_parser import Phi4MiniJsonToolParser +from .pythonic_tool_parser import PythonicToolParser +from .step3_tool_parser import Step3ToolParser +from .xlam_tool_parser import xLAMToolParser + +__all__ = [ + "ToolParser", + "ToolParserManager", + "Granite20bFCToolParser", + "GraniteToolParser", + "Hermes2ProToolParser", + "MistralToolParser", + "Internlm2ToolParser", + "Llama3JsonToolParser", + "JambaToolParser", + "Llama4PythonicToolParser", + "PythonicToolParser", + "Phi4MiniJsonToolParser", + "DeepSeekV3ToolParser", + "Step3ToolParser", + "xLAMToolParser", + "MinimaxToolParser", + "Glm4MoeModelToolParser", +] diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py new file mode 100644 index 0000000..02aeab6 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from collections.abc import Sequence +from functools import cached_property +from typing import Callable, Optional, Union + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage, + ExtractedToolCallInformation) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import import_from_path, is_list_of + +logger = init_logger(__name__) + + +class ToolParser: + """ + Abstract ToolParser class that should not be used directly. Provided + properties and methods should be used in + derived classes. + """ + + def __init__(self, tokenizer: AnyTokenizer): + self.prev_tool_call_arr: list[dict] = [] + # the index of the tool call that is currently being parsed + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: list[str] = [] + + self.model_tokenizer = tokenizer + + @cached_property + def vocab(self) -> dict[str, int]: + # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab + # whereas all tokenizers have .get_vocab() + return self.model_tokenizer.get_vocab() + + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + """ + Static method that used to adjust the request parameters. + """ + return request + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Static method that should be implemented for extracting tool calls from + a complete model-generated string. + Used for non-streaming responses where we have the entire model response + available before sending to the client. + Static because it's stateless. + """ + raise NotImplementedError( + "AbstractToolParser.extract_tool_calls has not been implemented!") + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + """ + Instance method that should be implemented for extracting tool calls + from an incomplete response; for use when handling tool calls and + streaming. Has to be an instance method because it requires state - + the current tokens/diffs, but also the information about what has + previously been parsed and extracted (see constructor) + """ + raise NotImplementedError( + "AbstractToolParser.extract_tool_calls_streaming has not been " + "implemented!") + + +class ToolParserManager: + tool_parsers: dict[str, type] = {} + + @classmethod + def get_tool_parser(cls, name) -> type: + """ + Get tool parser by name which is registered by `register_module`. + + Raise a KeyError exception if the name is not registered. + """ + if name in cls.tool_parsers: + return cls.tool_parsers[name] + + raise KeyError(f"tool helper: '{name}' not found in tool_parsers") + + @classmethod + def _register_module(cls, + module: type, + module_name: Optional[Union[str, list[str]]] = None, + force: bool = True) -> None: + if not issubclass(module, ToolParser): + raise TypeError( + f'module must be subclass of ToolParser, but got {type(module)}' + ) + if module_name is None: + module_name = module.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in cls.tool_parsers: + existed_module = cls.tool_parsers[name] + raise KeyError(f'{name} is already registered ' + f'at {existed_module.__module__}') + cls.tool_parsers[name] = module + + @classmethod + def register_module( + cls, + name: Optional[Union[str, list[str]]] = None, + force: bool = True, + module: Union[type, None] = None) -> Union[type, Callable]: + """ + Register module with the given name or name list. it can be used as a + decoder(with module as None) or normal function(with module as not + None). + """ + if not isinstance(force, bool): + raise TypeError(f'force must be a boolean, but got {type(force)}') + + # raise the error ahead of time + if not (name is None or isinstance(name, str) + or is_list_of(name, str)): + raise TypeError( + 'name must be None, an instance of str, or a sequence of str, ' + f'but got {type(name)}') + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + cls._register_module(module=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(module): + cls._register_module(module=module, module_name=name, force=force) + return module + + return _register + + @classmethod + def import_tool_parser(cls, plugin_path: str) -> None: + """ + Import a user-defined tool parser by the path of the tool parser define + file. + """ + module_name = os.path.splitext(os.path.basename(plugin_path))[0] + + try: + import_from_path(module_name, plugin_path) + except Exception: + logger.exception("Failed to load module '%s' from %s.", + module_name, plugin_path) + return diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py new file mode 100644 index 0000000..da4760a --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py @@ -0,0 +1,370 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Union + +import regex as re + +from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("deepseek_v3") +class DeepSeekV3ToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id: int = -1 + self.streamed_args_for_tool: list[str] = ( + []) # map what has been streamed for each tool so far to a list + + self.tool_calls_start_token: str = "<|tool▁calls▁begin|>" + self.tool_calls_end_token: str = "<|tool▁calls▁end|>" + + self.tool_call_start_token: str = "<|tool▁call▁begin|>" + self.tool_call_end_token: str = "<|tool▁call▁end|>" + + self.tool_call_regex = re.compile( + r"<|tool▁call▁begin|>(?P.*)<|tool▁sep|>(?P.*)\n```json\n(?P.*)\n```<|tool▁call▁end|>" + ) + + self.stream_tool_call_portion_regex = re.compile( + r"(?P.*)<|tool▁sep|>(?P.*)\n```json\n(?P.*[^\n`])" + ) + + self.stream_tool_call_name_regex = re.compile( + r"(?P.*)<|tool▁sep|>(?P.*)\n") + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + self.tool_calls_start_token_id = self.vocab.get( + self.tool_calls_start_token) + self.tool_calls_end_token_id = self.vocab.get( + self.tool_calls_end_token) + + self.tool_call_start_token_id = self.vocab.get( + self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + + if (self.tool_calls_start_token_id is None + or self.tool_calls_end_token_id is None): + raise RuntimeError( + "DeepSeek-V3 Tool parser could not locate tool call start/end " + "tokens in the tokenizer!") + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if self.tool_calls_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + else: + try: + # there are two possible captures - between tags, or between a + # tag and end-of-string so the result of + # findall is an array of tuples where one is a function call and + # the other is None + function_call_tuples = self.tool_call_regex.findall( + model_output) + + tool_calls = [] + for match in function_call_tuples: + tool_type, function_name, function_args = match + tool_calls.append( + ToolCall( + type=tool_type, + function=FunctionCall(name=function_name, + arguments=function_args), + )) + + content = model_output[:model_output. + find(self.tool_calls_start_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None, + ) + + except Exception: + logger.exception( + "Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + logger.debug("delta_text: %s", delta_text) + logger.debug("delta_token_ids: %s", delta_token_ids) + # check to see if we should be streaming a tool call - is there a + if self.tool_calls_start_token_id not in current_token_ids: + logger.debug("No tool call tokens found!") + return DeltaMessage(content=delta_text) + delta_text = delta_text.replace(self.tool_calls_start_token, + "").replace(self.tool_calls_end_token, + "") + try: + + # figure out where we are in the parsing by counting tool call + # start & end tags + prev_tool_start_count = previous_token_ids.count( + self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count( + self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count( + self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count( + self.tool_call_end_token_id) + tool_call_portion = None + text_portion = None + + # case: if we're generating text, OR rounding out a tool call + if (cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text): + logger.debug("Generating text content! skipping tool parsing.") + return DeltaMessage(content=delta_text) + + if self.tool_call_end_token in delta_text: + logger.debug("tool_call_end_token in delta_text") + full_text = current_text + delta_text + tool_call_portion = full_text.split( + self.tool_call_start_token)[-1].split( + self.tool_call_end_token)[0].rstrip() + delta_text = delta_text.split( + self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split( + self.tool_call_end_token)[-1].lstrip() + + # case -- we're starting a new tool call + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): + if len(delta_token_ids) > 1: + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + else: + tool_call_portion = None + delta = None + + text_portion = None + + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", self.current_tool_id) + + # case -- we're updating an existing tool call + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): + + # get the portion of the text that's the tool call + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + + # case -- the current tool call is being closed. + elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count): + if self.prev_tool_call_arr is None or len( + self.prev_tool_call_arr) == 0: + logger.debug( + "attempting to close tool call, but no tool call") + return None + diff = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + if diff: + diff = (diff.encode("utf-8").decode("unicode_escape") + if diff is str else diff) + if '"}' not in delta_text: + return None + end_loc = delta_text.rindex('"}') + diff = delta_text[:end_loc] + '"}' + logger.debug( + "Finishing tool and found diff that had not " + "been streamed yet: %s", + diff, + ) + self.streamed_args_for_tool[self.current_tool_id] += diff + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump(exclude_none=True), + ) + ]) + + # case -- otherwise we're just generating text + else: + text = delta_text.replace(self.tool_call_start_token, "") + text = text.replace(self.tool_call_end_token, "") + delta = DeltaMessage(tool_calls=[], content=text) + return delta + + current_tool_call = dict() + if tool_call_portion: + current_tool_call_matches = ( + self.stream_tool_call_portion_regex.match( + tool_call_portion)) + if current_tool_call_matches: + tool_type, tool_name, tool_args = ( + current_tool_call_matches.groups()) + current_tool_call["name"] = tool_name + current_tool_call["arguments"] = tool_args + else: + current_tool_call_name_matches = ( + self.stream_tool_call_name_regex.match( + tool_call_portion)) + if current_tool_call_name_matches: + tool_type, tool_name = ( + current_tool_call_name_matches.groups()) + current_tool_call["name"] = tool_name + current_tool_call["arguments"] = "" + else: + logger.debug("Not enough token") + return None + + # case - we haven't sent the tool name yet. If it's available, send + # it. otherwise, wait until it's available. + if not self.current_tool_name_sent: + if current_tool_call is None: + return None + function_name: Union[str, None] = current_tool_call.get("name") + if function_name: + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=random_tool_call_id(), + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True), + ) + ]) + else: + return None + + # case -- otherwise, send the tool call delta + + # if the tool call portion is None, send the delta as text + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = (DeltaMessage( + content=delta_text) if text_portion is not None else None) + return delta + + # now, the nitty-gritty of tool calls + # now we have the portion to parse as tool call. + + logger.debug("Trying to parse current tool call with ID %s", + self.current_tool_id) + + # if we're starting a new tool call, push an empty object in as + # a placeholder for the arguments + if len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + + # main logic for tool parsing here - compare prev. partially-parsed + # JSON to the current partially-parsed JSON + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + cur_arguments = current_tool_call.get("arguments") + + logger.debug("diffing old arguments: %s", prev_arguments) + logger.debug("against new ones: %s", cur_arguments) + + # case -- no arguments have been created yet. skip sending a delta. + if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", delta_text) + delta = None + + # case -- prev arguments are defined, but non are now. + # probably impossible, but not a fatal error - just keep going + elif not cur_arguments and prev_arguments: + logger.error("should be impossible to have arguments reset " + "mid-call. skipping streaming anything.") + delta = None + + # case -- we now have the first info about arguments available from + # autocompleting the JSON + elif cur_arguments and not prev_arguments: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=cur_arguments).model_dump( + exclude_none=True), + ) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments + + # last case -- we have an update to existing arguments. + elif cur_arguments and prev_arguments: + if (isinstance(delta_text, str) + and cur_arguments != prev_arguments + and len(cur_arguments) > len(prev_arguments) + and cur_arguments.startswith(prev_arguments)): + delta_arguments = cur_arguments[len(prev_arguments):] + logger.debug("got diff %s", delta_text) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=delta_arguments).model_dump( + exclude_none=True), + ) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments + else: + delta = None + + # handle saving the state for the current tool into + # the "prev" list for use in diffing for the next iteration + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[ + self.current_tool_id] = current_tool_call + else: + self.prev_tool_call_arr.append(current_tool_call) + + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + return None # do not stream a delta. skip this token ID. diff --git a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py new file mode 100644 index 0000000..c40788d --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py @@ -0,0 +1,402 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# code modified from deepseekv3_tool_parser.py + +from collections.abc import Sequence +from typing import Union + +import regex as re + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("glm4_moe") +class Glm4MoeModelToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + self.current_tool_name_sent = False + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id = -1 + self.streamed_args_for_tool: list[str] = [] + self.tool_call_start_token = "" + self.tool_call_end_token = "" + + self.tool_calls_start_token = self.tool_call_start_token + + # Updated regex for the XML-based format + self.tool_call_regex = re.compile( + r"\s*" + r"(?P[^\n<]+)\s*" # 函数名(到换行或 <) + r"(?P(?:\s*[^<]+\s*" + r"[^<]*\s*)*)\s*" + r"", + re.DOTALL, + ) + + # Regex for parsing individual arguments + self.arg_regex = re.compile( + r"(?P[^<]+)\s*(?P[^<]*)", + re.DOTALL, + ) + + # Streaming regex + self.stream_tool_call_portion_regex = re.compile( + r"(?P[^\n<]+)\s*" + r"(?P(?:\s*[^<]+\s*" + r"[^<]*\s*)*)", + re.DOTALL, + ) + + # For streaming, we also need a regex to match just the function name + self.stream_tool_call_name_regex = re.compile( + r"(?P[^\n<]+)", + re.DOTALL, + ) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + + self.tool_call_start_token_id = self.vocab.get( + self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + + def _parse_arguments(self, args_text: str) -> str: + """Parse XML-based arguments into JSON format.""" + if not args_text or not args_text.strip(): + return "{}" + + args_dict = {} + matches = self.arg_regex.findall(args_text) + + for key, value in matches: + args_dict[key.strip()] = value.strip() + + import json + return json.dumps(args_dict, ensure_ascii=False) + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if self.tool_calls_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + try: + # Find all tool calls in the output + function_call_matches = self.tool_call_regex.findall(model_output) + + logger.debug("function_call_matches: %s", function_call_matches) + + if not function_call_matches: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output, + ) + + tool_calls = [] + for i, match in enumerate(function_call_matches): + function_name, function_args_xml = match + function_name = function_name.strip() + + # Parse XML arguments to JSON + function_args_json = self._parse_arguments(function_args_xml) + + tool_calls.append( + ToolCall( + id=f"call_{i}", + type='function', + function=FunctionCall(name=function_name, + arguments=function_args_json), + )) + + # Extract content before the first tool call + content = model_output[:model_output.find(self. + tool_calls_start_token)] + return ExtractedToolCallInformation( + tools_called=bool(tool_calls), + tool_calls=tool_calls, + content=content.strip() if content.strip() else None, + ) + + except Exception: + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + logger.debug("delta_text: %s", delta_text) + logger.debug("delta_token_ids: %s", delta_token_ids) + # check to see if we should be streaming a tool call - is there a + if self.tool_call_start_token_id not in current_token_ids: + logger.debug("No tool call tokens found!") + return DeltaMessage(content=delta_text) + delta_text = delta_text.replace(self.tool_calls_start_token, + "").replace(self.tool_call_end_token, + "") + try: + + # figure out where we are in the parsing by counting tool call + # start & end tags + prev_tool_start_count = previous_token_ids.count( + self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count( + self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count( + self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count( + self.tool_call_end_token_id) + tool_call_portion = None + text_portion = None + + # case: if we're generating text, OR rounding out a tool call + if (cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text): + logger.debug("Generating text content! skipping tool parsing.") + return DeltaMessage(content=delta_text) + + if self.tool_call_end_token in delta_text: + logger.debug("tool_call_end_token in delta_text") + full_text = current_text + delta_text + tool_call_portion = full_text.split( + self.tool_call_start_token)[-1].split( + self.tool_call_end_token)[0].rstrip() + delta_text = delta_text.split( + self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split( + self.tool_call_end_token)[-1].lstrip() + + # case -- we're starting a new tool call + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): + if len(delta_token_ids) > 1: + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + else: + tool_call_portion = None + delta = None + + text_portion = None + + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", self.current_tool_id) + + # case -- we're updating an existing tool call + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): + + # get the portion of the text that's the tool call + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + + # case -- the current tool call is being closed. + elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count): + if self.prev_tool_call_arr is None or len( + self.prev_tool_call_arr) == 0: + logger.debug( + "attempting to close tool call, but no tool call") + return None + diff = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + if diff: + diff = (diff.encode("utf-8").decode("unicode_escape") + if diff is str else diff) + if '"}' not in delta_text: + return None + end_loc = delta_text.rindex('"}') + diff = delta_text[:end_loc] + '"}' + logger.debug( + "Finishing tool and found diff that had not " + "been streamed yet: %s", + diff, + ) + self.streamed_args_for_tool[self.current_tool_id] += diff + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump(exclude_none=True), + ) + ]) + + # case -- otherwise we're just generating text + else: + text = delta_text.replace(self.tool_call_start_token, "") + text = text.replace(self.tool_call_end_token, "") + delta = DeltaMessage(tool_calls=[], content=text) + return delta + + current_tool_call = dict() + if tool_call_portion: + current_tool_call_matches = ( + self.stream_tool_call_portion_regex.match( + tool_call_portion)) + if current_tool_call_matches: + tool_id, tool_args = (current_tool_call_matches.groups()) + tool_name = tool_id.split('.')[1].split(':')[0] + current_tool_call['id'] = tool_id + current_tool_call["name"] = tool_name + current_tool_call["arguments"] = tool_args + else: + current_tool_call_name_matches = ( + self.stream_tool_call_name_regex.match( + tool_call_portion)) + if current_tool_call_name_matches: + tool_id_str, = current_tool_call_name_matches.groups() + tool_name = tool_id_str.split('.')[1].split(':')[0] + current_tool_call['id'] = tool_id_str + current_tool_call["name"] = tool_name + current_tool_call["arguments"] = "" + else: + logger.debug("Not enough token") + return None + + # case - we haven't sent the tool name yet. If it's available, send + # it. otherwise, wait until it's available. + if not self.current_tool_name_sent: + if current_tool_call is None: + return None + function_name: Union[str, None] = current_tool_call.get("name") + tool_id = current_tool_call.get("id") + if function_name: + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True), + ) + ]) + else: + return None + + # case -- otherwise, send the tool call delta + + # if the tool call portion is None, send the delta as text + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = (DeltaMessage( + content=delta_text) if text_portion is not None else None) + return delta + + # now, the nitty-gritty of tool calls + # now we have the portion to parse as tool call. + + logger.debug("Trying to parse current tool call with ID %s", + self.current_tool_id) + + # if we're starting a new tool call, push an empty object in as + # a placeholder for the arguments + if len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + + # main logic for tool parsing here - compare prev. partially-parsed + # JSON to the current partially-parsed JSON + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + cur_arguments = current_tool_call.get("arguments") + + logger.debug("diffing old arguments: %s", prev_arguments) + logger.debug("against new ones: %s", cur_arguments) + + # case -- no arguments have been created yet. skip sending a delta. + if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", delta_text) + delta = None + + # case -- prev arguments are defined, but non are now. + # probably impossible, but not a fatal error - just keep going + elif not cur_arguments and prev_arguments: + logger.error("should be impossible to have arguments reset " + "mid-call. skipping streaming anything.") + delta = None + + # case -- we now have the first info about arguments available from + # autocompleting the JSON + elif cur_arguments and not prev_arguments: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=cur_arguments).model_dump( + exclude_none=True), + ) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments + + # last case -- we have an update to existing arguments. + elif cur_arguments and prev_arguments: + if (isinstance(delta_text, str) + and cur_arguments != prev_arguments + and len(cur_arguments) > len(prev_arguments) + and cur_arguments.startswith(prev_arguments)): + delta_arguments = cur_arguments[len(prev_arguments):] + logger.debug("got diff %s", delta_text) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=delta_arguments).model_dump( + exclude_none=True), + ) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments + else: + delta = None + + # handle saving the state for the current tool into + # the "prev" list for use in diffing for the next iteration + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[ + self.current_tool_id] = current_tool_call + else: + self.prev_tool_call_arr.append(current_tool_call) + + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + return None # do not stream a delta. skip this token ID. \ No newline at end of file diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py new file mode 100644 index 0000000..5508ba6 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py @@ -0,0 +1,259 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from collections.abc import Sequence +from json import JSONDecoder +from typing import Union + +import partial_json_parser +import regex as re +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import (consume_space, + find_common_prefix, + is_complete_json, + partial_json_loads) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("granite-20b-fc") +class Granite20bFCToolParser(ToolParser): + """ + Tool call parser for the granite-20b-functioncalling model intended + for use with the examples/tool_chat_template_granite20b_fc.jinja + template. + + Used when --enable-auto-tool-choice --tool-call-parser granite-20-fc + are all set + """ + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + self.bot_token = "" + self.tool_start_token = self.bot_token + self.tool_call_regex = re.compile(r"\s*") + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + if self.tool_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + dec = JSONDecoder() + try: + matches = list(self.tool_call_regex.finditer(model_output)) + logger.debug("Found %d tool call matches", len(matches)) + + raw_function_calls = [] + + for i, match in enumerate(matches): + # position after the tag + start_of_json = match.end() + # end_index == the start of the next function call + # (if exists) + next_function_call_start = (matches[i + 1].start() if i + + 1 < len(matches) else None) + + raw_function_calls.append( + dec.raw_decode( + model_output[start_of_json:next_function_call_start]) + [0]) + + logger.debug("Extracted %d tool calls", len(raw_function_calls)) + tool_calls = [ + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(function_call["arguments"], + ensure_ascii=False), + ), + ) for function_call in raw_function_calls + ] + + content = model_output[:model_output.find(self.bot_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None, + ) + + except Exception as e: + logger.error("Error in extracting tool call from response %s", e) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + if len(current_text) < len( + self.bot_token) and self.bot_token.startswith(current_text): + return None + + if not current_text.startswith(self.bot_token): + return DeltaMessage(content=delta_text) + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + tool_call_arr = [] + is_complete = [] + try: + start_idx = len(self.bot_token) + start_idx = consume_space(start_idx, current_text) + + while start_idx < len(current_text): + (obj, + end_idx) = partial_json_loads(current_text[start_idx:], + flags) + is_complete.append( + is_complete_json(current_text[start_idx:start_idx + + end_idx])) + start_idx += end_idx + start_idx = consume_space(start_idx, current_text) + start_idx += len(self.bot_token) + start_idx = consume_space(start_idx, current_text) + tool_call_arr.append(obj) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # select as the current tool call the one we're on the state at + current_tool_call: dict = tool_call_arr[self.current_tool_id] \ + if len(tool_call_arr) > 0 else {} + + # case -- if no tokens have been streamed for the tool, e.g. + # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return None + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + elif (len(tool_call_arr) > 0 + and len(tool_call_arr) > self.current_tool_id + 1): + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + cur_arguments = current_tool_call.get("arguments") + if cur_arguments: + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = cur_args_json[sent:] + + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + delta = None + else: + delta = None + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + elif not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=random_tool_call_id(), + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + + # now we know we're on the same tool call and we're streaming + # arguments + else: + cur_arguments = current_tool_call.get("arguments") + delta = None + + if cur_arguments: + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + + argument_diff = None + if is_complete[self.current_tool_id]: + argument_diff = cur_args_json[sent:] + elif prev_arguments: + prev_args_json = json.dumps(prev_arguments, + ensure_ascii=False) + if cur_args_json != prev_args_json: + + prefix = find_common_prefix( + prev_args_json, cur_args_json) + argument_diff = prefix[sent:] + + if argument_diff is not None: + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py new file mode 100644 index 0000000..fcc5b7e --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py @@ -0,0 +1,237 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from collections.abc import Sequence +from typing import Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import (consume_space, + find_common_prefix, + is_complete_json, + partial_json_loads) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("granite") +class GraniteToolParser(ToolParser): + """ + Tool call parser for the granite 3.0 models. Intended + for use with the examples/tool_chat_template_granite.jinja + template. + + Used when --enable-auto-tool-choice --tool-call-parser granite + are all set + """ + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + # for granite 3.0, the token `<|tool_call|>` + self.bot_token = "<|tool_call|>" + # for granite 3.1, the string `` + self.bot_string = "" + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + stripped = model_output.strip()\ + .removeprefix(self.bot_token)\ + .removeprefix(self.bot_string)\ + .lstrip() + if not stripped or stripped[0] != '[': + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + try: + raw_function_calls = json.loads(stripped) + if not isinstance(raw_function_calls, list): + raise Exception( + f"Expected dict or list, got {type(raw_function_calls)}") + + logger.debug("Extracted %d tool calls", len(raw_function_calls)) + tool_calls = [ + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(function_call["arguments"], + ensure_ascii=False), + ), + ) for function_call in raw_function_calls + ] + + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=None, + ) + + except Exception as e: + logger.error("Error in extracting tool call from response %s", e) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + start_idx = consume_space(0, current_text) + if current_text[start_idx:].startswith(self.bot_token): + start_idx = consume_space(start_idx + len(self.bot_token), + current_text) + if current_text[start_idx:].startswith(self.bot_string): + start_idx = consume_space(start_idx + len(self.bot_string), + current_text) + if not current_text or start_idx >= len(current_text)\ + or current_text[start_idx] != '[': + return DeltaMessage(content=delta_text) + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + tool_call_arr = None + is_complete = None + try: + tool_calls, end_idx = partial_json_loads( + current_text[start_idx:], flags) + if type(tool_calls) is list: + tool_call_arr = tool_calls + else: + return DeltaMessage(content=delta_text) + + is_complete = [True] * len(tool_calls) + if not is_complete_json( + current_text[start_idx:start_idx + end_idx]): + is_complete[-1] = False + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # case -- if no tokens have been streamed for the tool, e.g. + # only the array brackets, stream nothing + if not tool_call_arr: + return None + + # select as the current tool call the one we're on the state at + current_tool_call: dict = tool_call_arr[self.current_tool_id] + + delta = None + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + if len(tool_call_arr) > self.current_tool_id + 1: + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + cur_arguments = current_tool_call.get("arguments") + if cur_arguments: + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = cur_args_json[sent:] + + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + elif not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=random_tool_call_id(), + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + + # now we know we're on the same tool call and we're streaming + # arguments + else: + cur_arguments = current_tool_call.get("arguments") + + if cur_arguments: + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + + argument_diff = None + if is_complete[self.current_tool_id]: + argument_diff = cur_args_json[sent:] + elif prev_arguments: + prev_args_json = json.dumps(prev_arguments, + ensure_ascii=False) + if cur_args_json != prev_args_json: + prefix = find_common_prefix( + prev_args_json, cur_args_json) + argument_diff = prefix[sent:] + + if argument_diff is not None: + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py new file mode 100644 index 0000000..c7030d3 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -0,0 +1,371 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from collections.abc import Sequence +from typing import Union + +import partial_json_parser +import regex as re +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("hermes") +class Hermes2ProToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if isinstance(self.model_tokenizer, MistralTokenizer): + logger.error( + "Detected Mistral tokenizer when using a Hermes model") + self.model_tokenizer = self.model_tokenizer.tokenizer + + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id: int = -1 + self.streamed_args_for_tool: list[str] = [ + ] # map what has been streamed for each tool so far to a list + + self.tool_call_start_token: str = "" + self.tool_call_end_token: str = "" + + self.tool_call_regex = re.compile( + r"(.*?)|(.*)", re.DOTALL) + self.scratch_pad_regex = re.compile( + r"(.*?)", re.DOTALL) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + self.tool_call_start_token_id = self.vocab.get( + self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + if (self.tool_call_start_token_id is None + or self.tool_call_end_token_id is None): + raise RuntimeError( + "Hermes 2 Pro Tool parser could not locate tool call start/end " + "tokens in the tokenizer!") + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if self.tool_call_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + else: + + try: + # there are two possible captures - between tags, or between a + # tag and end-of-string so the result of + # findall is an array of tuples where one is a function call and + # the other is None + function_call_tuples = ( + self.tool_call_regex.findall(model_output)) + + # load the JSON, and then use it to build the Function and + # Tool Call + raw_function_calls = [ + json.loads(match[0] if match[0] else match[1]) + for match in function_call_tuples + ] + tool_calls = [ + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(function_call["arguments"], + ensure_ascii=False))) + for function_call in raw_function_calls + ] + + content = model_output[:model_output. + find(self.tool_call_start_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None) + + except Exception: + logger.exception( + "Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + logger.debug("delta_text: %s", delta_text) + logger.debug("delta_token_ids: %s", delta_token_ids) + # check to see if we should be streaming a tool call - is there a + if self.tool_call_start_token_id not in current_token_ids: + logger.debug("No tool call tokens found!") + return DeltaMessage(content=delta_text) + + try: + + # figure out where we are in the parsing by counting tool call + # start & end tags + prev_tool_start_count = previous_token_ids.count( + self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count( + self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count( + self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count( + self.tool_call_end_token_id) + tool_call_portion = None + text_portion = None + + # case: if we're generating text, OR rounding out a tool call + if (cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text): + logger.debug("Generating text content! skipping tool parsing.") + return DeltaMessage(content=delta_text) + + if self.tool_call_end_token in delta_text: + logger.debug("tool_call_end_token in delta_text") + full_text = current_text + delta_text + tool_call_portion = full_text.split( + self.tool_call_start_token)[-1].split( + self.tool_call_end_token)[0].rstrip() + delta_text = delta_text.split( + self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split( + self.tool_call_end_token)[-1].lstrip() + + # case: if tool open & close tag counts don't match, we're doing + # imaginary "else" block here + # something with tools with this diff. + # flags for partial JSON parting. exported constants from + # "Allow" are handled via BIT MASK + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + + # case -- we're starting a new tool call + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): + if len(delta_token_ids) > 1: + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + else: + tool_call_portion = None + delta = None + + text_portion = None + + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", self.current_tool_id) + + # case -- we're updating an existing tool call + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): + + # get the portion of the text that's the tool call + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + + # case -- the current tool call is being closed. + elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count): + if (self.prev_tool_call_arr is None + or len(self.prev_tool_call_arr) == 0): + logger.debug( + "attempting to close tool call, but no tool call") + return None + diff = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + if diff: + diff = diff.encode('utf-8').decode( + 'unicode_escape') if diff is str else diff + if ('"}' not in delta_text): + return None + end_loc = delta_text.rindex('"}') + diff = delta_text[:end_loc] + '"}' + logger.debug( + "Finishing tool and found diff that had not " + "been streamed yet: %s", diff) + self.streamed_args_for_tool[self.current_tool_id] \ + += diff + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) + + # case -- otherwise we're just generating text + else: + text = delta_text.replace(self.tool_call_start_token, "") + text = text.replace(self.tool_call_end_token, "") + delta = DeltaMessage(tool_calls=[], content=text) + return delta + + try: + + current_tool_call = partial_json_parser.loads( + tool_call_portion or "{}", + flags) if tool_call_portion else None + logger.debug("Parsed tool call %s", current_tool_call) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + except json.decoder.JSONDecodeError: + logger.debug("unable to parse JSON") + return None + + # case - we haven't sent the tool name yet. If it's available, send + # it. otherwise, wait until it's available. + if not self.current_tool_name_sent: + if (current_tool_call is None): + return None + function_name: Union[str, None] = current_tool_call.get("name") + if function_name: + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=random_tool_call_id(), + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + else: + return None + # case -- otherwise, send the tool call delta + + # if the tool call portion is None, send the delta as text + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = DeltaMessage(content=delta_text) \ + if text_portion is not None else None + return delta + + # now, the nitty-gritty of tool calls + # now we have the portion to parse as tool call. + + logger.debug("Trying to parse current tool call with ID %s", + self.current_tool_id) + + # if we're starting a new tool call, push an empty object in as + # a placeholder for the arguments + if len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + + # main logic for tool parsing here - compare prev. partially-parsed + # JSON to the current partially-parsed JSON + prev_arguments = ( + self.prev_tool_call_arr[self.current_tool_id].get("arguments")) + cur_arguments = current_tool_call.get("arguments") + + logger.debug("diffing old arguments: %s", prev_arguments) + logger.debug("against new ones: %s", cur_arguments) + + # case -- no arguments have been created yet. skip sending a delta. + if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", delta_text) + delta = None + + # case -- prev arguments are defined, but non are now. + # probably impossible, but not a fatal error - just keep going + elif not cur_arguments and prev_arguments: + logger.error("should be impossible to have arguments reset " + "mid-call. skipping streaming anything.") + delta = None + + # case -- we now have the first info about arguments available from + # autocompleting the JSON + elif cur_arguments and not prev_arguments: + + cur_arguments_json = json.dumps(cur_arguments, + ensure_ascii=False) + logger.debug("finding %s in %s", delta_text, + cur_arguments_json) + + # get the location where previous args differ from current + if (delta_text not in cur_arguments_json[:-2]): + return None + args_delta_start_loc = cur_arguments_json[:-2]. \ + rindex(delta_text) + \ + len(delta_text) + + # use that to find the actual delta + arguments_delta = cur_arguments_json[:args_delta_start_loc] + logger.debug("First tokens in arguments received: %s", + arguments_delta) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[self.current_tool_id] \ + += arguments_delta + + # last case -- we have an update to existing arguments. + elif cur_arguments and prev_arguments: + if isinstance(delta_text, str) and len(delta_text.rstrip( + )) >= 1 and delta_text.rstrip()[-1] == '}': + delta_text = delta_text.rstrip()[:-1] + + logger.debug("got diff %s", delta_text) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=delta_text).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[self.current_tool_id] \ + += delta_text + + # handle saving the state for the current tool into + # the "prev" list for use in diffing for the next iteration + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[self.current_tool_id] = \ + current_tool_call + else: + self.prev_tool_call_arr.append(current_tool_call) + + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + return None # do not stream a delta. skip this token ID. diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py new file mode 100644 index 0000000..92004de --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -0,0 +1,216 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from collections.abc import Sequence +from typing import Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module(["internlm"]) +class Internlm2ToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + self.position = 0 + + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != 'none': + # do not skip special tokens because internlm use the special + # tokens to indicated the start and end of the tool calls + # information. + request.skip_special_tokens = False + return request + + def get_arguments(self, obj): + if "parameters" in obj: + return obj.get("parameters") + elif "arguments" in obj: + return obj.get("arguments") + return None + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + if '<|action_start|>' not in current_text: + self.position = len(current_text) + return DeltaMessage(content=delta_text) + # if the tool call is sended, return a empty delta message + # to make sure the finish_reason will be send correctly. + if self.current_tool_id > 0: + return DeltaMessage(content='') + + last_pos = self.position + if '<|action_start|><|plugin|>' not in current_text[last_pos:]: + return None + + new_delta = current_text[last_pos:] + text, action = new_delta.split('<|action_start|><|plugin|>') + + if len(text) > 0: + self.position = self.position + len(text) + return DeltaMessage(content=text) + + action = action.strip() + action = action.split('<|action_end|>'.strip())[0] + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + + try: + parsable_arr = action + + # tool calls are generated in an object in inernlm2 + # it's not support parallel tool calls + try: + tool_call_arr: dict = partial_json_parser.loads( + parsable_arr, flags) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + if not self.current_tool_name_sent: + function_name = tool_call_arr.get("name") + if function_name: + self.current_tool_id = self.current_tool_id + 1 + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=random_tool_call_id(), + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + self.streamed_args_for_tool.append("") + else: + delta = None + # now we know we're on the same tool call and we're streaming + # arguments + else: + prev_arguments = self.get_arguments( + self.prev_tool_call_arr[self.current_tool_id]) + cur_arguments = self.get_arguments(tool_call_arr) + + # not arguments generated + if not cur_arguments and not prev_arguments: + delta = None + # will never happen + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments reset " + "mid-arguments") + delta = None + # first time to get parameters + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments, + ensure_ascii=False) + + arguments_delta = cur_arguments_json[:cur_arguments_json. + index(delta_text) + + len(delta_text)] + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + # both prev and cur parameters, send the increase parameters + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, + ensure_ascii=False) + + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + # check to see if the name is defined and has been sent. if so, + # stream the name - otherwise keep waiting + # finish by setting old and returning None as base case + tool_call_arr["arguments"] = self.get_arguments(tool_call_arr) + self.prev_tool_call_arr = [tool_call_arr] + return delta + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + text = model_output + tools = request.tools + if '<|action_start|><|plugin|>' in text: + text, action = text.split('<|action_start|><|plugin|>') + action = action.split('<|action_end|>'.strip())[0] + action = action[action.find('{'):] + action_dict = json.loads(action) + name, parameters = action_dict['name'], json.dumps( + action_dict.get('parameters', action_dict.get('arguments', + {})), + ensure_ascii=False) + + if not tools or name not in [t.function.name for t in tools]: + ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=text) + + tool_calls = [ + ToolCall( + function=FunctionCall(name=name, arguments=parameters)) + ] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=text if len(text) > 0 else None) + + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=text) diff --git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py new file mode 100644 index 0000000..66b483d --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py @@ -0,0 +1,308 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from collections.abc import Sequence +from typing import Union + +import partial_json_parser +import regex as re +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizers import MistralTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("jamba") +class JambaToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if isinstance(self.model_tokenizer, MistralTokenizer): + raise ValueError( + "Detected a MistralTokenizer tokenizer when using a Jamba model" + ) + + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id: int = -1 + self.streamed_args_for_tool: list[str] = [ + ] # map what has been streamed for each tool so far to a list + + self.tool_calls_start_token: str = "" + self.tool_calls_end_token: str = "" + + self.tool_calls_regex = re.compile( + rf"{self.tool_calls_start_token}(.*?){self.tool_calls_end_token}", + re.DOTALL) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + self.tool_calls_start_token_id = self.vocab.get( + self.tool_calls_start_token) + self.tool_calls_end_token_id = self.vocab.get( + self.tool_calls_end_token) + if (self.tool_calls_start_token_id is None + or self.tool_calls_end_token_id is None): + raise RuntimeError( + "Jamba Tool parser could not locate tool calls start/end " + "tokens in the tokenizer!") + + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != 'none': + # do not skip special tokens because jamba use the special + # tokens to indicate the start and end of the tool calls + # information. + request.skip_special_tokens = False + return request + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if self.tool_calls_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + else: + + try: + # use a regex to find the tool call between the tags + function_calls = self.tool_calls_regex.findall(model_output)[0] + + # load the JSON, and then use it to build the Function and + # Tool Call + raw_function_calls = json.loads(function_calls) + tool_calls = [ + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(function_call["arguments"], + ensure_ascii=False), + )) for function_call in raw_function_calls + ] + + content = model_output[:model_output. + find(self.tool_calls_start_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if + (len(content) > 0 and content != " ") else None) + + except Exception: + logger.exception( + "Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + # if the tool call token is not in the tokens generated so far, append + # output to contents since it's not a tool + if self.tool_calls_start_token not in current_text: + return DeltaMessage(content=delta_text) + + # if the tool call token ID IS in the tokens generated so far, that + # means we're parsing as tool calls now + + # handle if we detected the start of tool calls token which means + # the start of tool calling + if (self.tool_calls_start_token_id in delta_token_ids + and len(delta_token_ids) == 1): + # if it's the only token, return None, so we don't send a chat + # completion and don't send a control token + return None + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + + # Extract the tool calls between the special tool call tokens + parsable_arr = current_text.split( + self.tool_calls_start_token)[-1].split( + self.tool_calls_end_token)[0] + + # tool calls are generated in an array, so do partial JSON + # parsing on the entire array + try: + tool_call_arr: list[dict] = partial_json_parser.loads( + parsable_arr, flags) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # select as the current tool call the one we're on the state at + + current_tool_call: dict = tool_call_arr[self.current_tool_id] \ + if len(tool_call_arr) > 0 else {} + + # case -- if no tokens have been streamed for the tool, e.g. + # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return None + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + elif (len(tool_call_arr) > 0 + and len(tool_call_arr) > self.current_tool_id + 1): + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + diff: Union[str, None] = current_tool_call.get("arguments") + + if diff: + diff = json.dumps(diff, ensure_ascii=False).replace( + self.streamed_args_for_tool[self.current_tool_id], + "") + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += diff + else: + delta = None + else: + delta = None + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # case: update an existing tool - this is handled below + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + if not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=random_tool_call_id(), + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + + # now we know we're on the same tool call and we're streaming + # arguments + else: + + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + cur_arguments = current_tool_call.get("arguments") + + new_text = delta_text.replace("\'", "\"") + + if not cur_arguments and not prev_arguments: + + delta = None + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments reset " + "mid-arguments") + delta = None + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments, + ensure_ascii=False) + logger.debug("finding %s in %s", new_text, + cur_arguments_json) + + arguments_delta = cur_arguments_json[:cur_arguments_json. + index(new_text) + + len(new_text)] + logger.debug("First tokens in arguments received: %s", + arguments_delta) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, + ensure_ascii=False) + logger.debug("Searching for diff between \n%s\n%s", + cur_args_json, prev_args_json) + + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + # try parsing it with regular JSON - if it works we're + # at the end, and we need to send the difference between + # tokens streamed so far and the valid JSON + delta = None + + # check to see if the name is defined and has been sent. if so, + # stream the name - otherwise keep waiting + # finish by setting old and returning None as base case + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None diff --git a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py new file mode 100644 index 0000000..6bf44a4 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py @@ -0,0 +1,316 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import ast +import json +from collections.abc import Sequence +from typing import Any, Union + +import regex as re +from transformers import PreTrainedTokenizerBase + +import vllm.envs as envs +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class _UnexpectedAstError(Exception): + pass + + +@ToolParserManager.register_module("llama4_pythonic") +class Llama4PythonicToolParser(ToolParser): + """ + Toolcall parser for Llama4 that produce tool calls in a pythonic style + Use --enable-auto-tool-choice --tool-call-parser llama4_pythonic + """ + # TODO(mdepinet): Possible future improvements: + # 1. Support text + tools separated by either <|python_tag|> or \n\n + # 2. Support tools outside of a list (or separated by a semicolon). + # This depends on item 1 for consistent streaming. + # Neither of these are necessary for e.g. ToolACE, but both would help make + # Llama3.2 models more reliable. + + TOOL_CALL_REGEX = re.compile( + r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]", + re.DOTALL) + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + + # Rename for readability. This is NOT a tool id. + @property + def current_tool_index(self) -> int: + return self.current_tool_id + + @current_tool_index.setter + def current_tool_index(self, value: int) -> None: + self.current_tool_id = value + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. + """ + + # remove <|python_start|> and <|python_end|> + # as Llama 4 model sometime will output those tokens + if model_output.startswith("<|python_start|>"): + model_output = model_output[len("<|python_start|>"):] + model_output = model_output.replace("<|python_end|>", "") + + is_tool_call_pattern = False + try: + is_tool_call_pattern = self.TOOL_CALL_REGEX.match( + model_output, + timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None + except TimeoutError: + logger.warning( + "Regex timeout occurred when matching tool call pattern.") + logger.debug("Regex timeout occurred when matching user input: %s", + model_output) + + if not is_tool_call_pattern: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + try: + module = ast.parse(model_output) + parsed = getattr(module.body[0], "value", None) + if isinstance(parsed, ast.List) and all( + isinstance(e, ast.Call) for e in parsed.elts): + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=[ + _handle_single_tool(e) # type: ignore + for e in parsed.elts + ], + content=None) + else: + raise _UnexpectedAstError( + "Tool output must be a list of function calls") + except Exception: + logger.exception("Error in extracting tool call from response.") + # Treat as regular text + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + if not current_text.startswith("[") and not current_text.startswith( + "<|python_start|>"): + return DeltaMessage(content=delta_text) + + try: + # remove <|python_start|> and <|python_end|> + if current_text.startswith("<|python_start|>"): + current_text = current_text[len("<|python_start|>"):] + if current_text.endswith("<|python_end|>"): + current_text = current_text[:current_text. + rfind("<|python_end|>")] + valid_and_added_text = _make_valid_python(current_text) + if valid_and_added_text is None: + return None + valid_text, added_text = valid_and_added_text + + module = ast.parse(valid_text) + parsed = getattr(module.body[0], "value", None) + if not isinstance(parsed, ast.List) or not all( + isinstance(e, ast.Call) for e in parsed.elts): + raise _UnexpectedAstError( + "Tool output must be a list of function calls") + tool_calls = [ + _handle_single_tool(e) # type: ignore + for e in parsed.elts + ] + + tool_deltas = [] + for index, new_call in enumerate(tool_calls): + if index < self.current_tool_index: + continue + + self.current_tool_index = index + if len(self.streamed_args_for_tool) == index: + self.streamed_args_for_tool.append("") + + new_call_complete = index < len( + tool_calls) - 1 or ")]" not in added_text + if new_call_complete: + self.current_tool_index += 1 + + withheld_suffix = (added_text[:-2] + if not new_call_complete else "") + if not new_call_complete and added_text[-2] == ")": + # Function call is incomplete. Withhold the closing bracket. + withheld_suffix = withheld_suffix + "}" + # Strings get single quotes in the model-produced string. + # JSON requires double quotes. + withheld_suffix = withheld_suffix.replace("'", '"') + delta = _compute_tool_delta(self.streamed_args_for_tool[index], + new_call, index, withheld_suffix) + + if delta is not None: + tool_deltas.append(delta) + if (delta.function is not None + and delta.function.arguments is not None): + self.streamed_args_for_tool[ + index] += delta.function.arguments + + # HACK: serving_chat.py inspects the internal state of tool parsers + # when determining it's final streaming delta, automatically + # adding autocompleted JSON. + # These two lines avoid that nonsense while ensuring finish_reason + # is set to tool_calls when at least one tool is called. + if tool_deltas and not self.prev_tool_call_arr: + self.prev_tool_call_arr = [{"arguments": {}}] + + if tool_deltas: + return DeltaMessage(tool_calls=tool_deltas) + elif not added_text and self.current_tool_id > 0: + # Return an empty DeltaMessage once the tool calls are all done + # so that finish_reason gets set. + return DeltaMessage(content='') + else: + return None + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None + + +def _get_parameter_value(val: ast.expr) -> Any: + if isinstance(val, ast.Constant): + return val.value + elif isinstance(val, ast.Dict): + if not all(isinstance(k, ast.Constant) for k in val.keys): + raise _UnexpectedAstError( + "Dict tool call arguments must have literal keys") + return { + k.value: _get_parameter_value(v) # type: ignore + for k, v in zip(val.keys, val.values) + } + elif isinstance(val, ast.List): + return [_get_parameter_value(v) for v in val.elts] + else: + raise _UnexpectedAstError("Tool call arguments must be literals") + + +def _handle_single_tool(call: ast.Call) -> ToolCall: + if not isinstance(call.func, ast.Name): + raise _UnexpectedAstError("Invalid tool call name") + function_name = call.func.id + arguments = {} + for keyword in call.keywords: + arguments[keyword.arg] = _get_parameter_value(keyword.value) + return ToolCall(type="function", + function=FunctionCall(name=function_name, + arguments=json.dumps(arguments))) + + +def _make_valid_python(text: str) -> Union[tuple[str, str], None]: + bracket_stack = [] + for index, char in enumerate(text): + if char in {"[", "(", "{"}: + bracket_stack.append(char) + elif char == "]": + if not bracket_stack or bracket_stack.pop() != "[": + raise _UnexpectedAstError("Mismatched square brackets") + elif char == ")": + if not bracket_stack or bracket_stack.pop() != "(": + raise _UnexpectedAstError("Mismatched parentheses") + elif char == "}": + if not bracket_stack or bracket_stack.pop() != "{": + raise _UnexpectedAstError("Mismatched curly braces") + elif char in {"'", '"'}: + if bracket_stack and bracket_stack[-1] == char: + if index > 0 and text[index - 1] == "\\": + # Treat an escaped quote as a regular character + pass + else: + bracket_stack.pop() + elif bracket_stack and bracket_stack[-1] in {"'", '"'}: + # Double quote within a single quote string or vice versa. + pass + else: + bracket_stack.append(char) + + text = text.rstrip() + if text.endswith("=") or text.endswith(":"): + # Since we have no type information for this property/parameter value, + # we can't fill in a valid value. + return None + if bracket_stack and bracket_stack[-1] == "{": + trailing_dict_text = text[:text.rfind("{")] + num_keys = trailing_dict_text.count(":") + num_values = trailing_dict_text.count(",") + if num_keys <= num_values: + return None # Incomplete property name within parameter value + if bracket_stack and bracket_stack[-1] == "(": + trailing_params_text = text[:text.rfind("(")] + num_full_param_names = trailing_params_text.count("=") + num_full_param_values = trailing_params_text.count(",") + if num_full_param_names <= num_full_param_values: + return None # Incomplete parameter name + if text.endswith(","): + text = text[:-1] + if bracket_stack and bracket_stack[-1] == "[" and not text.endswith( + "[") and not text.endswith(")"): + return None # Incomplete function name + + added_text = "" + for char in reversed(bracket_stack): + if char == "[": + added_text += "]" + elif char == "(": + added_text += ")" + elif char == "{": + added_text += "}" + elif char == "'": + added_text += "'" + elif char == '"': + added_text += '"' + + return text + added_text, added_text + + +def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall, + index: int, + withheld_suffix: str) -> Union[DeltaToolCall, None]: + new_call_args = new_call.function.arguments + if withheld_suffix: + assert new_call_args.endswith(withheld_suffix) + new_call_args = new_call_args[:-len(withheld_suffix)] + if not previously_sent_args: + return DeltaToolCall(id=new_call.id, + type="function", + index=index, + function=DeltaFunctionCall( + name=new_call.function.name, + arguments=new_call_args, + )) + + arg_diff = new_call_args[len(previously_sent_args):] + return DeltaToolCall( + id=None, index=index, function=DeltaFunctionCall( + arguments=arg_diff)) if arg_diff else None diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py new file mode 100644 index 0000000..5698bc7 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -0,0 +1,267 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from collections.abc import Sequence +from json import JSONDecoder +from typing import Union + +import partial_json_parser +import regex as re +from partial_json_parser.core.options import Allow +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix, + is_complete_json, + partial_json_loads) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("llama3_json") +@ToolParserManager.register_module("llama4_json") +class Llama3JsonToolParser(ToolParser): + """ + Tool call parser for Llama 3.1 models intended for use with the + examples/tool_chat_template_llama.jinja template. + + Used when --enable-auto-tool-choice --tool-call-parser llama3_json + are all set + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + + # initialize properties used for state when parsing tool calls in + # streaming mode + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: list[str] = [ + ] # map what has been streamed for each tool so far to a list + self.bot_token = "<|python_tag|>" + self.bot_token_id = tokenizer.encode(self.bot_token, + add_special_tokens=False)[0] + self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL) + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. + """ + # case -- if a tool call token is not present, return a text response + if not (model_output.startswith(self.bot_token) + or model_output.startswith('{')): + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + try: + # load the JSON, and then use it to build the Function and + # Tool Call + dec = JSONDecoder() + function_call_arr = [] + + # depending on the prompt format the Llama model may or may not + # prefix the output with the <|python_tag|> token + start_idx = len(self.bot_token) if model_output.startswith( + self.bot_token) else 0 + while start_idx < len(model_output): + (obj, end_idx) = dec.raw_decode(model_output[start_idx:]) + start_idx += end_idx + len('; ') + function_call_arr.append(obj) + + tool_calls: list[ToolCall] = [ + ToolCall( + type="function", + function=FunctionCall( + name=raw_function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(raw_function_call["arguments"] \ + if "arguments" in raw_function_call \ + else raw_function_call["parameters"], + ensure_ascii=False))) + for raw_function_call in function_call_arr + ] + + # get any content before the tool call + ret = ExtractedToolCallInformation(tools_called=True, + tool_calls=tool_calls, + content=None) + return ret + + except Exception: + logger.exception("Error in extracting tool call from response.") + # return information to just treat the tool call as regular JSON + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + if not (current_text.startswith(self.bot_token) + or current_text.startswith('{')): + return DeltaMessage(content=delta_text) + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + tool_call_arr = [] + is_complete = [] + try: + # depending on the prompt format the Llama model may or may not + # prefix the output with the <|python_tag|> token + start_idx = len(self.bot_token) if current_text.startswith( + self.bot_token) else 0 + while start_idx < len(current_text): + (obj, + end_idx) = partial_json_loads(current_text[start_idx:], + flags) + is_complete.append( + is_complete_json(current_text[start_idx:start_idx + + end_idx])) + start_idx += end_idx + len('; ') + # depending on the prompt Llama can use + # either arguments or parameters + if "parameters" in obj: + assert "arguments" not in obj, \ + "model generated both parameters and arguments" + obj["arguments"] = obj["parameters"] + tool_call_arr.append(obj) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # select as the current tool call the one we're on the state at + current_tool_call: dict = tool_call_arr[self.current_tool_id] \ + if len(tool_call_arr) > 0 else {} + + # case -- if no tokens have been streamed for the tool, e.g. + # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return None + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + elif (len(tool_call_arr) > 0 + and len(tool_call_arr) > self.current_tool_id + 1): + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + cur_arguments = current_tool_call.get("arguments") + if cur_arguments: + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = cur_args_json[sent:] + + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + delta = None + else: + delta = None + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + elif not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=random_tool_call_id(), + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + + # now we know we're on the same tool call and we're streaming + # arguments + else: + cur_arguments = current_tool_call.get("arguments") + delta = None + + if cur_arguments: + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + + argument_diff = None + if is_complete[self.current_tool_id]: + argument_diff = cur_args_json[sent:] + elif prev_arguments: + prev_args_json = json.dumps(prev_arguments, + ensure_ascii=False) + if cur_args_json != prev_args_json: + + prefix = find_common_prefix( + prev_args_json, cur_args_json) + argument_diff = prefix[sent:] + + if argument_diff is not None: + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py new file mode 100644 index 0000000..6ba32e3 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py @@ -0,0 +1,369 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from collections.abc import Sequence +from typing import Union + +import partial_json_parser +import regex as re +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("minimax") +class MinimaxToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id: int = -1 + self.streamed_args_for_tool: list[str] = [] + + self.tool_call_start_token: str = "" + self.tool_call_end_token: str = "" + + self.tool_call_regex = re.compile( + r"(.*?)|(.*)", re.DOTALL) + + # Add regex pattern for thinking tag + self.thinking_tag_pattern = r"(.*?)" + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + + self.tool_call_start_token_id = self.vocab.get( + self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + + if (self.tool_call_start_token_id is None + or self.tool_call_end_token_id is None): + logger.warning( + "Minimax Tool parser could not locate tool call start/end " + "tokens in the tokenizer. Falling back to string matching.") + + def preprocess_model_output(self, model_output: str) -> str: + """ + Remove tool calls from within thinking tags to avoid processing them. + """ + + def remove_tool_calls_from_think(match): + think_content = match.group(1) + # Remove tool_calls from within the think tag + cleaned_content = re.sub(r".*?", + "", + think_content, + flags=re.DOTALL) + return f"{cleaned_content}" + + # Process thinking tags and remove tool_calls from within them + processed_output = re.sub(self.thinking_tag_pattern, + remove_tool_calls_from_think, + model_output, + flags=re.DOTALL) + + return processed_output + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + + # Preprocess to remove tool calls from thinking tags + processed_output = self.preprocess_model_output(model_output) + + if self.tool_call_start_token not in processed_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + try: + function_call_tuples = ( + self.tool_call_regex.findall(processed_output)) + + raw_function_calls = [] + for match in function_call_tuples: + tool_call_content = match[0] if match[0] else match[1] + if tool_call_content.strip(): + lines = tool_call_content.strip().split('\n') + for line in lines: + line = line.strip() + if line and line.startswith('{') and line.endswith( + '}'): + try: + parsed_call = json.loads(line) + raw_function_calls.append(parsed_call) + except json.JSONDecodeError: + continue + + tool_calls = [] + for function_call in raw_function_calls: + if "name" in function_call and "arguments" in function_call: + tool_calls.append( + ToolCall(type="function", + function=FunctionCall( + name=function_call["name"], + arguments=json.dumps( + function_call["arguments"], + ensure_ascii=False)))) + + # Extract content before the first valid tool call + # Find the position in processed output, then map back to original + processed_pos = processed_output.find(self.tool_call_start_token) + if processed_pos != -1: + # Get the content before tool calls in processed output + processed_content = processed_output[:processed_pos].strip() + + if processed_content: + # Find the end of this content in the original output + # Look for the last non-empty line of processed content + lines = processed_content.split('\n') + for line in reversed(lines): + line = line.strip() + if line: + # Find this line in original output + pos = model_output.find(line) + if pos != -1: + content = model_output[:pos + len(line)] + break + else: + content = "" + else: + content = "" + else: + content = model_output + + return ExtractedToolCallInformation( + tools_called=len(tool_calls) > 0, + tool_calls=tool_calls, + content=content.strip() if content.strip() else None) + + except Exception: + logger.exception( + "An unexpected error occurred during tool call extraction.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + logger.debug("delta_text: %s", delta_text) + logger.debug("delta_token_ids: %s", delta_token_ids) + + # Preprocess to remove tool calls from thinking tags + processed_current_text = self.preprocess_model_output(current_text) + + if self.tool_call_start_token not in processed_current_text: + return DeltaMessage(content=delta_text) + + if (self.tool_call_start_token_id is not None + and self.tool_call_start_token_id in delta_token_ids + and len(delta_token_ids) == 1): + return None + + original_tool_call_start_pos = current_text.find( + self.tool_call_start_token) + if original_tool_call_start_pos > 0: + delta_start_pos = len(current_text) - len(delta_text) + if delta_start_pos < original_tool_call_start_pos: + content_part = delta_text + if delta_start_pos + len( + delta_text) > original_tool_call_start_pos: + content_part = delta_text[:original_tool_call_start_pos - + delta_start_pos] + if content_part: + return DeltaMessage(content=content_part) + + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + + try: + parsable_content = processed_current_text.split( + self.tool_call_start_token)[-1].split( + self.tool_call_end_token)[0] + + tool_call_arr = [] + if parsable_content.strip(): + lines = parsable_content.strip().split('\n') + for line in lines: + line = line.strip() + if line and (line.startswith('{') or '"name"' in line): + try: + if line.endswith('}'): + parsed_call = json.loads(line) + tool_call_arr.append(parsed_call) + else: + parsed_call = partial_json_parser.loads( + line, flags) + if parsed_call and isinstance( + parsed_call, dict): + tool_call_arr.append(parsed_call) + except (json.JSONDecodeError, partial_json_parser.core. + exceptions.MalformedJSON): + continue + + current_tool_call: dict = tool_call_arr[self.current_tool_id] \ + if len(tool_call_arr) > self.current_tool_id >= 0 else {} + + if len(tool_call_arr) == 0: + return None + + # Starting a new tool in the array + elif (len(tool_call_arr) > 0 + and len(tool_call_arr) > self.current_tool_id + 1): + + # Handle any missed arguments from previous tool + if self.current_tool_id >= 0 and self.current_tool_id < len( + self.prev_tool_call_arr): + prev_tool_call = self.prev_tool_call_arr[ + self.current_tool_id] + diff_arguments = prev_tool_call.get("arguments") + + if diff_arguments: + diff_arguments_json = json.dumps(diff_arguments, + ensure_ascii=False) + already_streamed = self.streamed_args_for_tool[ + self. + current_tool_id] if self.current_tool_id < len( + self.streamed_args_for_tool) else "" + + if diff_arguments_json != already_streamed: + diff = diff_arguments_json[len(already_streamed):] + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) + if self.current_tool_id < len( + self.streamed_args_for_tool): + self.streamed_args_for_tool[ + self.current_tool_id] = diff_arguments_json + else: + delta = None + else: + delta = None + else: + delta = None + + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # Send tool name if not sent yet + if not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=random_tool_call_id(), + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + + # Stream arguments + else: + prev_arguments = None + if (self.current_tool_id < len(self.prev_tool_call_arr) + and self.prev_tool_call_arr[self.current_tool_id]): + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + + cur_arguments = current_tool_call.get("arguments") + + if not cur_arguments and not prev_arguments: + delta = None + elif not cur_arguments and prev_arguments: + logger.error( + "Arguments reset mid-call, skipping streaming") + delta = None + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments, + ensure_ascii=False) + logger.debug("First tokens in arguments received: %s", + cur_arguments_json) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=cur_arguments_json). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments_json + + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, + ensure_ascii=False) + + logger.debug("Searching for diff between \n%s\n%s", + cur_args_json, prev_args_json) + + already_streamed = self.streamed_args_for_tool[ + self.current_tool_id] if self.current_tool_id < len( + self.streamed_args_for_tool) else "" + + if cur_args_json.startswith(already_streamed): + argument_diff = cur_args_json[len(already_streamed):] + elif cur_args_json != already_streamed: + argument_diff = cur_args_json + self.streamed_args_for_tool[self.current_tool_id] = "" + else: + argument_diff = "" + + if argument_diff: + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + delta = None + else: + delta = None + + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception: + logger.exception("An unexpected error occurred", + "during streaming tool call handling.") + return None diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py new file mode 100644 index 0000000..c0691f1 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -0,0 +1,369 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from collections.abc import Sequence +from random import choices +from string import ascii_letters, digits +from typing import Union + +import partial_json_parser +import regex as re +from partial_json_parser.core.options import Allow +from pydantic import Field + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer + +logger = init_logger(__name__) + +ALPHANUMERIC = ascii_letters + digits + + +class MistralToolCall(ToolCall): + id: str = Field( + default_factory=lambda: MistralToolCall.generate_random_id()) + + @staticmethod + def generate_random_id(): + # Mistral Tool Call Ids must be alphanumeric with a length of 9. + # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299 + return "".join(choices(ALPHANUMERIC, k=9)) + + @staticmethod + def is_valid_id(id: str) -> bool: + return id.isalnum() and len(id) == 9 + + +def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool: + return isinstance(model_tokenizer, MistralTokenizer) \ + and model_tokenizer.version >= 11 + + +@ToolParserManager.register_module("mistral") +class MistralToolParser(ToolParser): + """ + Tool call parser for Mistral 7B Instruct v0.3, intended for use with + - [`mistral_common`](https://github.com/mistralai/mistral-common/) + - the examples/tool_chat_template_mistral.jinja template. + + Used when --enable-auto-tool-choice --tool-call-parser mistral are all set + """ + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if not isinstance(self.model_tokenizer, MistralTokenizer): + logger.info("Non-Mistral tokenizer detected when using a Mistral " + "model...") + + # initialize properties used for state when parsing tool calls in + # streaming mode + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: list[str] = [ + ] # map what has been streamed for each tool so far to a list + self.bot_token = "[TOOL_CALLS]" + self.bot_token_id = self.vocab.get(self.bot_token) + self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) + if _is_fn_name_regex_support(self.model_tokenizer): + self.fn_name_regex = re.compile( + r'([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)', re.DOTALL) + else: + self.fn_name_regex = None + + if self.bot_token_id is None: + raise RuntimeError( + "Mistral Tool Parser could not locate the tool call token in " + "the tokenizer!") + + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if not isinstance( + self.model_tokenizer, MistralTokenizer + ) and request.tools and request.tool_choice != 'none': + # Do not skip special tokens when using chat template + # with Mistral parser as TOOL_CALL token is needed + # for tool detection. + # Note: we don't want skip_special_tokens=False + # with MistralTokenizer as it is incompatible + request.skip_special_tokens = False + return request + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. Requires + find-and-replacing single quotes with double quotes for JSON parsing, + make sure your tool call arguments don't ever include quotes! + """ + + # case -- if a tool call token is not present, return a text response + if self.bot_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + # first remove the BOT token + tool_content = model_output.replace(self.bot_token, "").strip() + + try: + # we first try to directly load the json as parsing very nested + # jsons is difficult + try: + if self.fn_name_regex: + matches = self.fn_name_regex.findall(tool_content) + + function_call_arr = [] + for match in matches: + fn_name = match[0] + args = match[1] + + # fn_name is encoded outside serialized json dump + # only arguments are serialized + function_call_arr.append({ + "name": fn_name, + "arguments": json.loads(args) + }) + else: + function_call_arr = json.loads(tool_content) + except json.JSONDecodeError: + # use a regex to find the part corresponding to the tool call. + # NOTE: This use case should not happen if the model is trained + # correctly. It's a easy possible fix so it's included, but + # can be brittle for very complex / highly nested tool calls + raw_tool_call = self.tool_call_regex.findall(tool_content)[0] + function_call_arr = json.loads(raw_tool_call) + + # Tool Call + tool_calls: list[MistralToolCall] = [ + MistralToolCall( + type="function", + function=FunctionCall( + name=raw_function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(raw_function_call["arguments"], + ensure_ascii=False))) + for raw_function_call in function_call_arr + ] + + # get any content before the tool call + content = model_output.split(self.bot_token)[0] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if len(content) > 0 else None) + + except Exception: + logger.exception("Error in extracting tool call from response.") + # return information to just treat the tool call as regular JSON + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=tool_content) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + # if the tool call token is not in the tokens generated so far, append + # output to contents since it's not a tool + if self.bot_token not in current_text: + return DeltaMessage(content=delta_text) + + # if the tool call token ID IS in the tokens generated so far, that + # means we're parsing as tool calls now + + # handle if we detected the BOT token which means the start of tool + # calling + if (self.bot_token_id in delta_token_ids + and len(delta_token_ids) == 1): + # if it's the only token, return None, so we don't send a chat + # completion any don't send a control token + return None + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + + # replace BOT token with empty string, and convert single quotes + # to double to allow parsing as JSON since mistral uses single + # quotes instead of double for tool calls + parsable_arr = current_text.split(self.bot_token)[-1] + + # tool calls are generated in an array, so do partial JSON + # parsing on the entire array + try: + tool_call_arr: list[dict] = partial_json_parser.loads( + parsable_arr, flags) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # select as the current tool call the one we're on the state at + + current_tool_call: dict = tool_call_arr[self.current_tool_id] \ + if len(tool_call_arr) > 0 else {} + + # case -- if no tokens have been streamed for the tool, e.g. + # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return None + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + elif (len(tool_call_arr) > 0 + and len(tool_call_arr) > self.current_tool_id + 1): + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + diff: Union[str, None] = current_tool_call.get("arguments") + + if diff: + diff = json.dumps(diff, ensure_ascii=False).replace( + self.streamed_args_for_tool[self.current_tool_id], + "") + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += diff + else: + delta = None + else: + delta = None + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # case: update an existing tool - this is handled below + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + if not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=MistralToolCall.generate_random_id(), + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + + # now we know we're on the same tool call and we're streaming + # arguments + else: + + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + cur_arguments = current_tool_call.get("arguments") + + new_text = delta_text.replace("\'", "\"") + if ('"}' in new_text): + new_text = new_text[:new_text.rindex('"}')] + + if not cur_arguments and not prev_arguments: + + delta = None + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments reset " + "mid-arguments") + delta = None + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments, + ensure_ascii=False)[:-2] + logger.debug("finding %s in %s", new_text, + cur_arguments_json) + + if (new_text not in cur_arguments_json): + return None + arguments_delta = cur_arguments_json[:cur_arguments_json. + rindex(new_text) + + len(new_text)] + logger.debug("First tokens in arguments received: %s", + arguments_delta) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, + ensure_ascii=False) + logger.debug("Searching for diff between \n%s\n%s", + cur_args_json, prev_args_json) + + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + # try parsing it with regular JSON - if it works we're + # at the end, and we need to send the difference between + # tokens streamed so far and the valid JSON + delta = None + + # check to see if the name is defined and has been sent. if so, + # stream the name - otherwise keep waiting + # finish by setting old and returning None as base case + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None diff --git a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py new file mode 100644 index 0000000..5501028 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from collections.abc import Sequence +from typing import Any, Optional + +import regex as re +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("phi4_mini_json") +class Phi4MiniJsonToolParser(ToolParser): + """ + Tool call parser for phi-4-mini models intended for use with the + examples/tool_chat_template_llama.jinja template. + + Used when --enable-auto-tool-choice --tool-call-parser phi4_mini_json + are all set + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None: + super().__init__(tokenizer) + + # initialize properties used for state when parsing tool calls in + # streaming mode + self.prev_tool_call_arr: list[dict[str, Any]] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: list[str] = [ + ] # map what has been streamed for each tool so far to a list + self.bot_token: str = "functools" + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. + """ + logger.debug("Model output: %s", model_output) + + pattern = r'functools\[(.*?)\]' + matches = re.search(pattern, model_output, re.DOTALL) + + if not matches: + logger.debug("No function calls found") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + try: + function_call_arr: list[dict[str, Any]] = [] + try: + json_content = '[' + matches.group(1) + ']' + + function_call_arr = json.loads(json_content) + logger.debug("Successfully extracted %d function calls", + len(function_call_arr)) + except json.JSONDecodeError as e: + logger.error( + "Failed to parse function calls from model output. " + "Error: %s", str(e)) + + tool_calls: list[ToolCall] = [ + ToolCall( + id=random_tool_call_id(), + type="function", + function=FunctionCall( + name=raw_function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps( + raw_function_call["arguments"] + if "arguments" in raw_function_call else + raw_function_call["parameters"], + ensure_ascii=False), + )) for raw_function_call in function_call_arr + ] + + # get any content before the tool call + ret = ExtractedToolCallInformation(tools_called=True, + tool_calls=tool_calls, + content=None) + return ret + + except Exception: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Optional[DeltaMessage]: + + return None diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py new file mode 100644 index 0000000..73329cd --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py @@ -0,0 +1,308 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import ast +import json +from collections.abc import Sequence +from typing import Any, Union + +import regex as re +from transformers import PreTrainedTokenizerBase + +import vllm.envs as envs +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class _UnexpectedAstError(Exception): + pass + + +@ToolParserManager.register_module("pythonic") +class PythonicToolParser(ToolParser): + """ + Tool call parser for models that produce tool calls in a pythonic style, + such as Llama 3.2 and Llama 4 models. + + Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set + """ + # TODO(mdepinet): Possible future improvements: + # 1. Support text + tools separated by either <|python_tag|> or \n\n + # 2. Support tools outside of a list (or separated by a semicolon). + # This depends on item 1 for consistent streaming. + # Neither of these are necessary for e.g. ToolACE, but both would help make + # Llama3.2 models more reliable. + + TOOL_CALL_REGEX = re.compile( + r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]", + re.DOTALL) + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + + # Rename for readability. This is NOT a tool id. + @property + def current_tool_index(self) -> int: + return self.current_tool_id + + @current_tool_index.setter + def current_tool_index(self, value: int) -> None: + self.current_tool_id = value + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. + """ + is_tool_call_pattern = False + try: + is_tool_call_pattern = self.TOOL_CALL_REGEX.match( + model_output, + timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None + except TimeoutError: + logger.warning( + "Regex timeout occurred when matching tool call pattern.") + logger.debug("Regex timeout occurred when matching user input: %s", + model_output) + + if not is_tool_call_pattern: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + try: + module = ast.parse(model_output) + parsed = getattr(module.body[0], "value", None) + if isinstance(parsed, ast.List) and all( + isinstance(e, ast.Call) for e in parsed.elts): + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=[ + _handle_single_tool(e) # type: ignore + for e in parsed.elts + ], + content=None) + else: + raise _UnexpectedAstError( + "Tool output must be a list of function calls") + except Exception: + logger.exception("Error in extracting tool call from response.") + # Treat as regular text + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + if not current_text.startswith("["): + return DeltaMessage(content=delta_text) + + try: + valid_and_added_text = _make_valid_python(current_text) + if valid_and_added_text is None: + return None + valid_text, added_text = valid_and_added_text + + module = ast.parse(valid_text) + parsed = getattr(module.body[0], "value", None) + if not isinstance(parsed, ast.List) or not all( + isinstance(e, ast.Call) for e in parsed.elts): + raise _UnexpectedAstError( + "Tool output must be a list of function calls") + tool_calls = [ + _handle_single_tool(e) # type: ignore + for e in parsed.elts + ] + + tool_deltas = [] + for index, new_call in enumerate(tool_calls): + if index < self.current_tool_index: + continue + + self.current_tool_index = index + if len(self.streamed_args_for_tool) == index: + self.streamed_args_for_tool.append("") + + new_call_complete = index < len( + tool_calls) - 1 or ")]" not in added_text + if new_call_complete: + self.current_tool_index += 1 + + withheld_suffix = (added_text[:-2] + if not new_call_complete else "") + if not new_call_complete and added_text[-2] == ")": + # Function call is incomplete. Withhold the closing bracket. + withheld_suffix = withheld_suffix + "}" + # Strings get single quotes in the model-produced string. + # JSON requires double quotes. + withheld_suffix = withheld_suffix.replace("'", '"') + delta = _compute_tool_delta(self.streamed_args_for_tool[index], + new_call, index, withheld_suffix) + + if delta is not None: + tool_deltas.append(delta) + if (delta.function is not None + and delta.function.arguments is not None): + self.streamed_args_for_tool[ + index] += delta.function.arguments + + # HACK: serving_chat.py inspects the internal state of tool parsers + # when determining it's final streaming delta, automatically + # adding autocompleted JSON. + # These two lines avoid that nonsense while ensuring finish_reason + # is set to tool_calls when at least one tool is called. + if tool_deltas and not self.prev_tool_call_arr: + self.prev_tool_call_arr = [{"arguments": {}}] + + if tool_deltas: + return DeltaMessage(tool_calls=tool_deltas) + elif not added_text and self.current_tool_id > 0: + # Return an empty DeltaMessage once the tool calls are all done + # so that finish_reason gets set. + return DeltaMessage(content='') + else: + return None + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None + + +def _get_parameter_value(val: ast.expr) -> Any: + if isinstance(val, ast.Constant): + return val.value + elif isinstance(val, ast.Dict): + if not all(isinstance(k, ast.Constant) for k in val.keys): + raise _UnexpectedAstError( + "Dict tool call arguments must have literal keys") + return { + k.value: _get_parameter_value(v) # type: ignore + for k, v in zip(val.keys, val.values) + } + elif isinstance(val, ast.List): + return [_get_parameter_value(v) for v in val.elts] + else: + raise _UnexpectedAstError("Tool call arguments must be literals") + + +def _handle_single_tool(call: ast.Call) -> ToolCall: + if not isinstance(call.func, ast.Name): + raise _UnexpectedAstError("Invalid tool call name") + function_name = call.func.id + arguments = {} + for keyword in call.keywords: + arguments[keyword.arg] = _get_parameter_value(keyword.value) + return ToolCall( + type="function", + function=FunctionCall(name=function_name, + arguments=json.dumps(arguments, + ensure_ascii=False)), + ) + + +def _make_valid_python(text: str) -> Union[tuple[str, str], None]: + bracket_stack = [] + for index, char in enumerate(text): + if char in {"[", "(", "{"}: + bracket_stack.append(char) + elif char == "]": + if not bracket_stack or bracket_stack.pop() != "[": + raise _UnexpectedAstError("Mismatched square brackets") + elif char == ")": + if not bracket_stack or bracket_stack.pop() != "(": + raise _UnexpectedAstError("Mismatched parentheses") + elif char == "}": + if not bracket_stack or bracket_stack.pop() != "{": + raise _UnexpectedAstError("Mismatched curly braces") + elif char in {"'", '"'}: + if bracket_stack and bracket_stack[-1] == char: + if index > 0 and text[index - 1] == "\\": + # Treat an escaped quote as a regular character + pass + else: + bracket_stack.pop() + elif bracket_stack and bracket_stack[-1] in {"'", '"'}: + # Double quote within a single quote string or vice versa. + pass + else: + bracket_stack.append(char) + + text = text.rstrip() + if text.endswith("=") or text.endswith(":"): + # Since we have no type information for this property/parameter value, + # we can't fill in a valid value. + return None + if bracket_stack and bracket_stack[-1] == "{": + trailing_dict_text = text[:text.rfind("{")] + num_keys = trailing_dict_text.count(":") + num_values = trailing_dict_text.count(",") + if num_keys <= num_values: + return None # Incomplete property name within parameter value + if bracket_stack and bracket_stack[-1] == "(": + trailing_params_text = text[:text.rfind("(")] + num_full_param_names = trailing_params_text.count("=") + num_full_param_values = trailing_params_text.count(",") + if num_full_param_names <= num_full_param_values: + return None # Incomplete parameter name + if text.endswith(","): + text = text[:-1] + if bracket_stack and bracket_stack[-1] == "[" and not text.endswith( + "[") and not text.endswith(")"): + return None # Incomplete function name + + added_text = "" + for char in reversed(bracket_stack): + if char == "[": + added_text += "]" + elif char == "(": + added_text += ")" + elif char == "{": + added_text += "}" + elif char == "'": + added_text += "'" + elif char == '"': + added_text += '"' + + return text + added_text, added_text + + +def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall, + index: int, + withheld_suffix: str) -> Union[DeltaToolCall, None]: + new_call_args = new_call.function.arguments + if withheld_suffix: + assert new_call_args.endswith(withheld_suffix) + new_call_args = new_call_args[:-len(withheld_suffix)] + if not previously_sent_args: + return DeltaToolCall(id=new_call.id, + type="function", + index=index, + function=DeltaFunctionCall( + name=new_call.function.name, + arguments=new_call_args, + )) + + arg_diff = new_call_args[len(previously_sent_args):] + return DeltaToolCall( + id=None, index=index, function=DeltaFunctionCall( + arguments=arg_diff)) if arg_diff else None diff --git a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py new file mode 100644 index 0000000..fbede06 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py @@ -0,0 +1,296 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import contextlib +import json +from collections.abc import Sequence +from typing import Any, Optional, Union + +import regex as re + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module(["step3"]) +class Step3ToolParser(ToolParser): + """ + Tool parser for a model that uses a specific XML-like format for tool calls. + This version uses a robust, stateful, cursor-based streaming parser and + consolidates tool arguments into a single message. + """ + + TOOL_CALLS_BEGIN = "<|tool_calls_begin|>" + TOOL_CALLS_END = "<|tool_calls_end|>" + TOOL_CALL_BEGIN = "<|tool_call_begin|>" + TOOL_CALL_END = "<|tool_call_end|>" + TOOL_SEP = "<|tool_sep|>" + SPECIAL_TOKENS = [ + TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END + ] + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + self.position = 0 + # Explicit state flags for robust streaming + self.tool_block_started = False + self.tool_block_finished = False + + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != 'none': + request.skip_special_tokens = False + return request + + @staticmethod + def _parse_steptml_invoke( + action_text: str + ) -> tuple[Optional[str], Optional[dict[str, str]]]: + func_name_match = re.search(r'', + action_text) + if not func_name_match: + return None, None + func_name = func_name_match.group(1) + + params: dict[str, str] = {} + param_matches = re.findall( + r'([^<]*)', + action_text) + for name, value in param_matches: + params[name] = value.strip() + return func_name, params + + def _cast_arguments( + self, + func_name: str, + params: dict[str, Any], + request: ChatCompletionRequest, + ) -> dict[str, Any]: + for tool in request.tools or []: + if tool.function.name == func_name: + schema = tool.function.parameters or {} + properties = schema.get("properties", {}) + for key, value in params.items(): + if not isinstance(value, str): + continue + prop = properties.get(key, {}) + typ = prop.get("type") + if typ == "string": + params[key] = value.strip() + elif typ == "integer": + with contextlib.suppress(ValueError): + params[key] = int(value) + elif typ == "number": + with contextlib.suppress(ValueError): + params[key] = float(value) + elif typ == "boolean": + lower_val = value.lower() + params[key] = lower_val == "true" if lower_val in ( + "true", "false") else value + elif typ == "null": + params[key] = None if value.lower( + ) == "null" else value + break + return params + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + # The main loop processes the stream from the last known position. + while True: + if self.position >= len(current_text): + return None # We've processed the entire stream. + + unprocessed_text = current_text[self.position:] + + # STATE: After all tools are done, all subsequent text is content. + if self.tool_block_finished: + self.position = len(current_text) + return DeltaMessage(content=unprocessed_text) + + # STATE: Before the tool block has started. + if not self.tool_block_started: + if unprocessed_text.startswith(self.TOOL_CALLS_BEGIN): + self.position += len(self.TOOL_CALLS_BEGIN) + self.tool_block_started = True + continue # Token consumed, re-loop. + + start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN) + if start_pos == -1: + if self.TOOL_CALLS_BEGIN.startswith( + unprocessed_text.strip()) and unprocessed_text: + return None # It's a prefix, wait. + self.position = len(current_text) + return DeltaMessage(content=unprocessed_text) + else: + content = unprocessed_text[:start_pos] + self.position += len(content) + return DeltaMessage(content=content) + + # STATE: Inside the main tool block. + offset = len(unprocessed_text) - len(unprocessed_text.lstrip()) + unprocessed_text = unprocessed_text.lstrip() + self.position += offset + + if unprocessed_text.startswith(self.TOOL_CALLS_END): + self.position += len(self.TOOL_CALLS_END) + self.tool_block_finished = True + self.current_tool_id = -1 + continue + + # Check if we are between tool calls. + tool_finished = ( + self.current_tool_id != -1 and + self.prev_tool_call_arr[self.current_tool_id].get("finished")) + if self.current_tool_id == -1 or tool_finished: + if unprocessed_text.startswith(self.TOOL_CALL_BEGIN): + self.position += len(self.TOOL_CALL_BEGIN) + if self.current_tool_id == -1: + self.current_tool_id = 0 + else: + self.current_tool_id += 1 + self.current_tool_name_sent = False + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + self.prev_tool_call_arr[ + self.current_tool_id]["finished"] = False + continue + + if self.TOOL_CALL_BEGIN.startswith(unprocessed_text): + return None + + # STATE: Parsing an active tool call. + if self.current_tool_id != -1 and not self.prev_tool_call_arr[ + self.current_tool_id].get("finished", False): + end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END) + if end_tool_pos == -1: + tool_body = unprocessed_text + else: + tool_body = unprocessed_text[:end_tool_pos] + + if end_tool_pos == -1 and self.TOOL_CALL_END.startswith( + tool_body): + return None + + function_name, arguments = self._parse_steptml_invoke( + tool_body) + if not function_name: + return None + + tool_call_arr = { + "name": function_name, + "parameters": arguments or {} + } + + # Send the function name as soon as it's parsed. + if not self.current_tool_name_sent: + self.current_tool_name_sent = True + self.prev_tool_call_arr[self.current_tool_id].update( + tool_call_arr) + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name)) + ]) + + # Update our internal state with the latest parsed arguments. + self.prev_tool_call_arr[ + self.current_tool_id].update( # noqa: E501 + tool_call_arr) + + # Only send arguments when the tool call is complete. + if end_tool_pos != -1: + self.position += end_tool_pos + len(self.TOOL_CALL_END) + self.prev_tool_call_arr[ + self.current_tool_id]["finished"] = True + + final_args = self._cast_arguments( + function_name, + tool_call_arr.get("parameters", {}), # type: ignore + request) + if final_args: + final_args_json = json.dumps(final_args, + ensure_ascii=False) + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=final_args_json)) + ]) + + # If tool is not finished, return None to wait for more tokens. + return None + + return None + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + if self.TOOL_CALLS_BEGIN not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1) + if self.TOOL_CALLS_END not in rest: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1) + content = (pre_text + post_text).strip() + + tool_calls: list[ToolCall] = [] + call_parts = tool_block.split(self.TOOL_CALL_BEGIN) + + for part in call_parts: + if not part or self.TOOL_CALL_END not in part: + continue + + call_content = part.split(self.TOOL_CALL_END, 1)[0] + if self.TOOL_SEP not in call_content: + continue + + type_part, invoke_part = call_content.split(self.TOOL_SEP, 1) + if type_part.strip() != "function": + continue + + function_name, params_dict = self._parse_steptml_invoke( + invoke_part) + + if function_name and params_dict is not None: + params_dict = self._cast_arguments(function_name, params_dict, + request) + params_str = json.dumps(params_dict, ensure_ascii=False) + tool_calls.append( + ToolCall(function=FunctionCall(name=function_name, + arguments=params_str))) + if tool_calls: + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) \ No newline at end of file diff --git a/vllm/entrypoints/openai/tool_parsers/utils.py b/vllm/entrypoints/openai/tool_parsers/utils.py new file mode 100644 index 0000000..aa41cd6 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/utils.py @@ -0,0 +1,124 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from json import JSONDecodeError, JSONDecoder +from typing import Any + +import partial_json_parser +from partial_json_parser.core.options import Allow + + +def find_common_prefix(s1: str, s2: str) -> str: + """ + Finds a common prefix that is shared between two strings, if there is one. + Order of arguments is NOT important. + + This function is provided as a UTILITY for extracting information from JSON + generated by partial_json_parser, to help in ensuring that the right tokens + are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. + + e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') -> + '{"fruit": "ap' + """ + prefix = '' + min_length = min(len(s1), len(s2)) + for i in range(0, min_length): + if s1[i] == s2[i]: + prefix += s1[i] + else: + break + return prefix + + +def find_common_suffix(s1: str, s2: str) -> str: + """ + Finds a common suffix shared between two strings, if there is one. Order of + arguments is NOT important. + Stops when the suffix ends OR it hits an alphanumeric character + + e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}' + """ + suffix = '' + min_length = min(len(s1), len(s2)) + for i in range(1, min_length + 1): + if s1[-i] == s2[-i] and not s1[-i].isalnum(): + suffix = s1[-i] + suffix + else: + break + return suffix + + +def extract_intermediate_diff(curr: str, old: str) -> str: + """ + Given two strings, extract the difference in the middle between two strings + that are known to have a common prefix and/or suffix. + + This function is provided as a UTILITY for extracting information from JSON + generated by partial_json_parser, to help in ensuring that the right tokens + are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. The order of arguments IS + important - the new version of the partially-parsed JSON must be the first + argument, and the secnod argument must be from the previous generation. + + What it returns, is tokens that should be streamed to the client. + + e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}') + -> 'ple' + + """ + suffix = find_common_suffix(curr, old) + + old = old[::-1].replace(suffix[::-1], '', 1)[::-1] + prefix = find_common_prefix(curr, old) + diff = curr + if len(suffix): + diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1] + + if len(prefix): + # replace the prefix only once in case it's mirrored + diff = diff.replace(prefix, '', 1) + + return diff + + +def find_all_indices(string: str, substring: str) -> list[int]: + """ + Find all (starting) indices of a substring in a given string. Useful for + tool call extraction + """ + indices = [] + index = -1 + while True: + index = string.find(substring, index + 1) + if index == -1: + break + indices.append(index) + return indices + + +# partial_json_parser doesn't support extra data and +# JSONDecoder.raw_decode doesn't support partial JSON +def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]: + try: + return (partial_json_parser.loads(input_str, flags), len(input_str)) + except JSONDecodeError as e: + if "Extra data" in e.msg: + dec = JSONDecoder() + return dec.raw_decode(input_str) + raise + + +def is_complete_json(input_str: str) -> bool: + try: + json.loads(input_str) + return True + except JSONDecodeError: + return False + + +def consume_space(i: int, s: str) -> int: + while i < len(s) and s[i].isspace(): + i += 1 + return i diff --git a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py new file mode 100644 index 0000000..321718b --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py @@ -0,0 +1,466 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa +import json +from collections.abc import Sequence +from typing import Any, Optional, Union + +import regex as re + +from vllm.entrypoints.chat_utils import random_tool_call_id +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("xlam") +class xLAMToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + # Initialize state for streaming mode + self.prev_tool_calls: list[dict] = [] + self.current_tool_id = -1 + self.current_tool_name_sent = False + self.streamed_args: list[str] = [ + ] # Track arguments sent for each tool + + # For backward compatibility with tests + self.current_tools_sent: list[bool] = [] + + # For backward compatibility with serving code + self.prev_tool_call_arr = [] + + # Regex patterns for preprocessing + self.json_code_block_patterns = [ + r"```(?:json)?\s*([\s\S]*?)```", + r"\[TOOL_CALLS\]([\s\S]*?)(?=\n|$)", + r"([\s\S]*?)", + ] + self.thinking_tag_pattern = r"([\s\S]*)" + + # Define streaming state type to be initialized later + self.streaming_state: dict[str, Any] = { + "current_tool_index": -1, + "tool_ids": [], + "sent_tools": [], + } + + def preprocess_model_output( + self, model_output: str) -> tuple[Optional[str], Optional[str]]: + """ + Preprocess the model output to extract content and potential tool calls. + Returns: + Tuple of (content, potential_tool_calls_json) + """ + # Check for thinking tag + thinking_match = re.search(self.thinking_tag_pattern, model_output) + if thinking_match: + content = model_output[:thinking_match.start() + + len("")].strip() + thinking_content = thinking_match.group(1).strip() + + # Try to parse the thinking content as JSON + try: + json.loads(thinking_content) + return content, thinking_content + except json.JSONDecodeError: + # If can't parse as JSON, look for JSON code blocks + for json_pattern in self.json_code_block_patterns: + json_matches = re.findall(json_pattern, thinking_content) + if json_matches: + for json_str in json_matches: + try: + json.loads(json_str) + return content, json_str + except json.JSONDecodeError: + continue + + # Check for JSON code blocks in the entire output + for json_pattern in self.json_code_block_patterns: + json_matches = re.findall(json_pattern, model_output) + if json_matches: + for json_str in json_matches: + try: + json.loads(json_str) + # Extract content by removing the JSON code block + content = re.sub(json_pattern, "", + model_output).strip() + return content, json_str + except json.JSONDecodeError: + continue + + # If the entire output is a valid JSON array or looks like one, treat it as tool calls + if model_output.strip().startswith("["): + try: + json.loads(model_output) + return None, model_output + except json.JSONDecodeError: + # Even if it's not valid JSON yet, it might be a tool call in progress + if ("{" in model_output and "name" in model_output + and "arguments" in model_output): + return None, model_output + + # If no tool calls found, return the original output as content + return model_output, None + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Extract tool calls from a complete model output. + """ + try: + # Preprocess the model output + content, potential_tool_calls = self.preprocess_model_output( + model_output) + + if not potential_tool_calls: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=content) + + # Parse the potential tool calls as JSON + tool_calls_data = json.loads(potential_tool_calls) + + # Ensure it's an array + if not isinstance(tool_calls_data, list): + logger.debug("Tool calls data is not an array") + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=content or model_output, + ) + + tool_calls: list[ToolCall] = [] + + for idx, call in enumerate(tool_calls_data): + if (not isinstance(call, dict) or "name" not in call + or "arguments" not in call): + logger.debug("Invalid tool call format at index %d", idx) + continue + + tool_call = ToolCall( + id=f"call_{idx}_{random_uuid()}", + type="function", + function=FunctionCall( + name=call["name"], + arguments=(json.dumps(call["arguments"]) if isinstance( + call["arguments"], dict) else call["arguments"]), + ), + ) + tool_calls.append(tool_call) + + return ExtractedToolCallInformation( + tools_called=len(tool_calls) > 0, + tool_calls=tool_calls, + content=content, + ) + + except Exception as e: + logger.exception("Error extracting tool calls: %s", str(e)) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + """ + Extract tool calls for streaming mode. + """ + # Simplify detection: if it begins with "[" treat it as a function call + is_function_call = (current_text.strip().startswith("[")) + + # If not a function call, return normal content + if not is_function_call: + return DeltaMessage(content=delta_text) + + try: + # Initialize streaming state if not exists + if not hasattr(self, "streaming_state"): + self.streaming_state = { + "current_tool_index": -1, + "tool_ids": [], + "sent_tools": [], # Track complete state of each tool + } + + # Try parsing as JSON to check for complete tool calls + try: + parsed_tools = json.loads(current_text) + if isinstance(parsed_tools, list): + # Update our tool array for next time + self.prev_tool_call_arr = parsed_tools + except json.JSONDecodeError: + # Not complete JSON yet, use regex for partial parsing + pass + + # Check for test-specific state setup (current_tools_sent) + # This handles the case where tests manually set current_tools_sent + if (hasattr(self, "current_tools_sent") # type: ignore + and len(self.current_tools_sent) > 0): + # If current_tools_sent is set to [False], it means the test wants us to send the name + if (len(self.current_tools_sent) == 1 + and self.current_tools_sent[0] is False): + # Extract the function name using regex + name_pattern = r'"name"\s*:\s*"([^"]+)"' + name_match = re.search(name_pattern, current_text) + if name_match: + function_name = name_match.group(1) + + # The test expects us to send just the name first + tool_id = random_tool_call_id() + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=0, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True), # type: ignore + ) + ]) + # Update state to reflect that we've sent the name + self.current_tools_sent = [True] + self.current_tool_id = 0 + self.streaming_state["current_tool_index"] = 0 + if len(self.streaming_state["sent_tools"]) == 0: + self.streaming_state["sent_tools"].append({ + "sent_name": + True, + "sent_arguments_prefix": + False, + "sent_arguments": + "", + }) + else: + self.streaming_state["sent_tools"][0][ + "sent_name"] = True + self.current_tool_name_sent = True + return delta + + # Use regex to identify tool calls in the output + name_pattern = r'"name"\s*:\s*"([^"]+)"' + name_matches = list(re.finditer(name_pattern, current_text)) + tool_count = len(name_matches) + + # If no tools found yet, return + if tool_count == 0: + return None + + # Ensure our state arrays are large enough + while len(self.streaming_state["sent_tools"]) < tool_count: + self.streaming_state["sent_tools"].append({ + "sent_name": + False, + "sent_arguments_prefix": + False, + "sent_arguments": + "", + }) + + while len(self.streaming_state["tool_ids"]) < tool_count: + self.streaming_state["tool_ids"].append(None) + + # Determine if we need to move to a new tool + current_idx = self.streaming_state["current_tool_index"] + + # If we haven't processed any tool yet or current tool is complete, move to next + if current_idx == -1 or current_idx < tool_count - 1: + next_idx = current_idx + 1 + + # If tool at next_idx has not been sent yet + if (next_idx < tool_count + and not self.streaming_state["sent_tools"][next_idx] + ["sent_name"]): + # Update indexes + self.streaming_state["current_tool_index"] = next_idx + self.current_tool_id = ( + next_idx # For backward compatibility + ) + current_idx = next_idx + + # Extract the tool name + tool_name = name_matches[current_idx].group(1) + + # Generate ID and send tool name + tool_id = f"call_{current_idx}_{random_uuid()}" + self.streaming_state["tool_ids"][current_idx] = tool_id + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=current_idx, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=tool_name).model_dump( + exclude_none=True), # type: ignore + ) + ]) + self.streaming_state["sent_tools"][current_idx][ + "sent_name"] = True + self.current_tool_name_sent = ( + True # For backward compatibility + ) + + # Keep track of streamed args for backward compatibility + while len(self.streamed_args) <= current_idx: + self.streamed_args.append("") + + return delta + + # Process arguments for the current tool + if current_idx >= 0 and current_idx < tool_count: + # Support both regular and empty argument objects + # First, check for the empty arguments case: "arguments": {} + empty_args_pattern = ( + r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}') + empty_args_match = re.search(empty_args_pattern, current_text) + + # Check if this tool has empty arguments + if empty_args_match and empty_args_match.start() > 0: + # Find which tool this empty arguments belongs to + empty_args_tool_idx = 0 + for i in range(tool_count): + if i == current_idx: + # If this is our current tool and it has empty arguments + if not self.streaming_state["sent_tools"][ + current_idx]["sent_arguments_prefix"]: + # Send empty object + self.streaming_state["sent_tools"][ + current_idx][ + "sent_arguments_prefix"] = True + self.streaming_state["sent_tools"][ + current_idx]["sent_arguments"] = "{}" + + # Update streamed_args for backward compatibility + while len(self.streamed_args) <= current_idx: + self.streamed_args.append("") + self.streamed_args[current_idx] += "{}" + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments="{}"). + model_dump( + exclude_none=True), # type: ignore + ) + ]) + + # Move to next tool if available + if current_idx < tool_count - 1: + self.streaming_state[ + "current_tool_index"] += 1 + self.current_tool_id = self.streaming_state[ + "current_tool_index"] + + return delta + + # Extract arguments for current tool using regex for non-empty arguments + args_pattern = r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})' + args_matches = list(re.finditer(args_pattern, current_text)) + + if current_idx < len(args_matches): + args_text = args_matches[current_idx].group(1) + + # Handle transition between tools + is_last_tool = current_idx == tool_count - 1 + + # Find where the arguments for our current tool end + if not is_last_tool: + # If we have more tools after this one, try to find the complete argument block + next_tool_pos = current_text.find( + "},{", args_matches[current_idx].start()) + if next_tool_pos != -1: + args_end_pos = (next_tool_pos + 1 + ) # +1 to include the '}' + args_text = (current_text[args_matches[current_idx] + .start():args_end_pos]. + split('"arguments":')[1].strip()) + + # If arguments haven't been sent yet + sent_args = self.streaming_state["sent_tools"][ + current_idx]["sent_arguments"] + + # If we haven't sent the opening bracket yet + if not self.streaming_state["sent_tools"][current_idx][ + "sent_arguments_prefix"] and args_text.startswith( + "{"): + self.streaming_state["sent_tools"][current_idx][ + "sent_arguments_prefix"] = True + self.streaming_state["sent_tools"][current_idx][ + "sent_arguments"] = "{" + + # Update streamed_args for backward compatibility + while len(self.streamed_args) <= current_idx: + self.streamed_args.append("") + self.streamed_args[current_idx] += "{" + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments="{").model_dump( + exclude_none=True), # type: ignore + ) + ]) + return delta + + # If we need to send more arguments + if args_text.startswith(sent_args): + # Calculate what part of arguments we need to send + args_diff = args_text[len(sent_args):] + + if args_diff: + # Update our state + self.streaming_state["sent_tools"][current_idx][ + "sent_arguments"] = args_text + + # Update streamed_args for backward compatibility + while len(self.streamed_args) <= current_idx: + self.streamed_args.append("") + self.streamed_args[current_idx] += args_diff + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments=args_diff).model_dump( + exclude_none=True), # type: ignore + ) + ]) + return delta + + # If the tool's arguments are complete, check if we need to move to the next tool + if args_text.endswith("}") and args_text == sent_args: + # This tool is complete, move to the next one in the next iteration + if current_idx < tool_count - 1: + self.streaming_state["current_tool_index"] += 1 + self.current_tool_id = self.streaming_state[ + "current_tool_index"] # For compatibility + + # If we got here, we couldn't determine what to stream next + return None + + except Exception as e: + logger.exception(f"Error in streaming tool calls: {e}") + # If we encounter an error, just return the delta text as regular content + return DeltaMessage(content=delta_text) diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py new file mode 100644 index 0000000..c4e044f --- /dev/null +++ b/vllm/entrypoints/score_utils.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Union + +from torch.nn import CosineSimilarity + +from vllm.outputs import PoolingRequestOutput +from vllm.transformers_utils.tokenizer import (PreTrainedTokenizer, + PreTrainedTokenizerFast) + + +def _cosine_similarity( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + embed_1: list[PoolingRequestOutput], + embed_2: list[PoolingRequestOutput], +) -> list[PoolingRequestOutput]: + + scorer = CosineSimilarity(0) + scores: Union[list[PoolingRequestOutput]] = [] + + for emb_1, emb_2 in zip(embed_1, embed_2): + pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data) + + padding = [] + if (pad_token_id := getattr(tokenizer, "pad_token_id", + None)) is not None: + padding = [pad_token_id] + + tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids + + scores.append( + PoolingRequestOutput( + request_id=f"{emb_1.request_id}_{emb_2.request_id}", + outputs=pair_score, + prompt_token_ids=tokens, + finished=True)) + + return scores + + +def _validate_score_input_lens( + texts_1: Union[list[str], list[dict]], + texts_2: Union[list[str], list[dict]], +): + if len(texts_1) > 1 and len(texts_1) != len(texts_2): + raise ValueError("Input lengths must be either 1:1, 1:N or N:N") + if len(texts_1) == 0: + raise ValueError("At least one text element must be given") + if len(texts_2) == 0: + raise ValueError("At least one text_pair element must be given") \ No newline at end of file diff --git a/vllm/entrypoints/ssl.py b/vllm/entrypoints/ssl.py new file mode 100644 index 0000000..e3646a6 --- /dev/null +++ b/vllm/entrypoints/ssl.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +from ssl import SSLContext +from typing import Callable, Optional + +from watchfiles import Change, awatch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class SSLCertRefresher: + """A class that monitors SSL certificate files and + reloads them when they change. + """ + + def __init__(self, + ssl_context: SSLContext, + key_path: Optional[str] = None, + cert_path: Optional[str] = None, + ca_path: Optional[str] = None) -> None: + self.ssl = ssl_context + self.key_path = key_path + self.cert_path = cert_path + self.ca_path = ca_path + + # Setup certification chain watcher + def update_ssl_cert_chain(change: Change, file_path: str) -> None: + logger.info("Reloading SSL certificate chain") + assert self.key_path and self.cert_path + self.ssl.load_cert_chain(self.cert_path, self.key_path) + + self.watch_ssl_cert_task = None + if self.key_path and self.cert_path: + self.watch_ssl_cert_task = asyncio.create_task( + self._watch_files([self.key_path, self.cert_path], + update_ssl_cert_chain)) + + # Setup CA files watcher + def update_ssl_ca(change: Change, file_path: str) -> None: + logger.info("Reloading SSL CA certificates") + assert self.ca_path + self.ssl.load_verify_locations(self.ca_path) + + self.watch_ssl_ca_task = None + if self.ca_path: + self.watch_ssl_ca_task = asyncio.create_task( + self._watch_files([self.ca_path], update_ssl_ca)) + + async def _watch_files(self, paths, fun: Callable[[Change, str], + None]) -> None: + """Watch multiple file paths asynchronously.""" + logger.info("SSLCertRefresher monitors files: %s", paths) + async for changes in awatch(*paths): + try: + for change, file_path in changes: + logger.info("File change detected: %s - %s", change.name, + file_path) + fun(change, file_path) + except Exception as e: + logger.error( + "SSLCertRefresher failed taking action on file change. " + "Error: %s", e) + + def stop(self) -> None: + """Stop watching files.""" + if self.watch_ssl_cert_task: + self.watch_ssl_cert_task.cancel() + self.watch_ssl_cert_task = None + if self.watch_ssl_ca_task: + self.watch_ssl_ca_task.cancel() + self.watch_ssl_ca_task = None diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py new file mode 100644 index 0000000..423b99d --- /dev/null +++ b/vllm/entrypoints/utils.py @@ -0,0 +1,262 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import asyncio +import functools +import os +import sys +from typing import Any, Optional, Union + +from fastapi import Request +from fastapi.responses import JSONResponse, StreamingResponse +from starlette.background import BackgroundTask, BackgroundTasks + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + CompletionRequest) +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +VLLM_SUBCMD_PARSER_EPILOG = ( + "Tip: Use `vllm [serve|run-batch|bench ] " + "--help=` to explore arguments from help.\n" + " - To view a argument group: --help=ModelConfig\n" + " - To view a single argument: --help=max-num-seqs\n" + " - To search by keyword: --help=max\n" + " - To list all groups: --help=listgroup") + + +async def listen_for_disconnect(request: Request) -> None: + """Returns if a disconnect message is received""" + while True: + message = await request.receive() + if message["type"] == "http.disconnect": + if request.app.state.enable_server_load_tracking: + # on timeout/cancellation the BackgroundTask in load_aware_call + # cannot decrement the server load metrics. + # Must be decremented by with_cancellation instead. + request.app.state.server_load_metrics -= 1 + break + + +def with_cancellation(handler_func): + """Decorator that allows a route handler to be cancelled by client + disconnections. + + This does _not_ use request.is_disconnected, which does not work with + middleware. Instead this follows the pattern from + starlette.StreamingResponse, which simultaneously awaits on two tasks- one + to wait for an http disconnect message, and the other to do the work that we + want done. When the first task finishes, the other is cancelled. + + A core assumption of this method is that the body of the request has already + been read. This is a safe assumption to make for fastapi handlers that have + already parsed the body of the request into a pydantic model for us. + This decorator is unsafe to use elsewhere, as it will consume and throw away + all incoming messages for the request while it looks for a disconnect + message. + + In the case where a `StreamingResponse` is returned by the handler, this + wrapper will stop listening for disconnects and instead the response object + will start listening for disconnects. + """ + + # Functools.wraps is required for this wrapper to appear to fastapi as a + # normal route handler, with the correct request type hinting. + @functools.wraps(handler_func) + async def wrapper(*args, **kwargs): + + # The request is either the second positional arg or `raw_request` + request = args[1] if len(args) > 1 else kwargs["raw_request"] + + handler_task = asyncio.create_task(handler_func(*args, **kwargs)) + cancellation_task = asyncio.create_task(listen_for_disconnect(request)) + + done, pending = await asyncio.wait([handler_task, cancellation_task], + return_when=asyncio.FIRST_COMPLETED) + for task in pending: + task.cancel() + + if handler_task in done: + return handler_task.result() + return None + + return wrapper + + +def decrement_server_load(request: Request): + request.app.state.server_load_metrics -= 1 + + +def load_aware_call(func): + + @functools.wraps(func) + async def wrapper(*args, **kwargs): + raw_request = kwargs.get("raw_request", + args[1] if len(args) > 1 else None) + + if raw_request is None: + raise ValueError( + "raw_request required when server load tracking is enabled") + + if not raw_request.app.state.enable_server_load_tracking: + return await func(*args, **kwargs) + + raw_request.app.state.server_load_metrics += 1 + try: + response = await func(*args, **kwargs) + except Exception: + raw_request.app.state.server_load_metrics -= 1 + raise + + if isinstance(response, (JSONResponse, StreamingResponse)): + if response.background is None: + response.background = BackgroundTask(decrement_server_load, + raw_request) + elif isinstance(response.background, BackgroundTasks): + response.background.add_task(decrement_server_load, + raw_request) + elif isinstance(response.background, BackgroundTask): + # Convert the single BackgroundTask to BackgroundTasks + # and chain the decrement_server_load task to it + tasks = BackgroundTasks() + tasks.add_task(response.background.func, + *response.background.args, + **response.background.kwargs) + tasks.add_task(decrement_server_load, raw_request) + response.background = tasks + else: + raw_request.app.state.server_load_metrics -= 1 + + return response + + return wrapper + + +def cli_env_setup(): + # The safest multiprocessing method is `spawn`, as the default `fork` method + # is not compatible with some accelerators. The default method will be + # changing in future versions of Python, so we should use it explicitly when + # possible. + # + # We only set it here in the CLI entrypoint, because changing to `spawn` + # could break some existing code using vLLM as a library. `spawn` will cause + # unexpected behavior if the code is not protected by + # `if __name__ == "__main__":`. + # + # References: + # - https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods + # - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing + # - https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors + # - https://docs.habana.ai/en/latest/PyTorch/Getting_Started_with_PyTorch_and_Gaudi/Getting_Started_with_PyTorch.html?highlight=multiprocessing#torch-multiprocessing-for-dataloaders + if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ: + logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'") + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + +def _validate_truncation_size( + max_model_len: int, + truncate_prompt_tokens: Optional[int], + tokenization_kwargs: Optional[dict[str, Any]] = None, +) -> Optional[int]: + + if truncate_prompt_tokens is not None: + if truncate_prompt_tokens <= -1: + truncate_prompt_tokens = max_model_len + + if truncate_prompt_tokens > max_model_len: + raise ValueError( + f"truncate_prompt_tokens value ({truncate_prompt_tokens}) " + f"is greater than max_model_len ({max_model_len})." + f" Please, select a smaller truncation size.") + + if tokenization_kwargs is not None: + tokenization_kwargs["truncation"] = True + tokenization_kwargs["max_length"] = truncate_prompt_tokens + + else: + if tokenization_kwargs is not None: + tokenization_kwargs["truncation"] = False + + return truncate_prompt_tokens + + +def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser, + subcommand_name: list[str]): + + # Only handle --help= for the current subcommand. + # Since subparser_init() runs for all subcommands during CLI setup, + # we skip processing if the subcommand name is not in sys.argv. + # sys.argv[0] is the program name. The subcommand follows. + # e.g., for `vllm bench latency`, + # sys.argv is `['vllm', 'bench', 'latency', ...]` + # and subcommand_name is "bench latency". + if len(sys.argv) <= len(subcommand_name) or sys.argv[ + 1:1 + len(subcommand_name)] != subcommand_name: + return + + for arg in sys.argv: + if arg.startswith('--help='): + search_keyword = arg.split('=', 1)[1] + + # List available groups + if search_keyword == 'listgroup': + print("\nAvailable argument groups:") + for group in parser._action_groups: + if group.title and not group.title.startswith( + "positional arguments"): + print(f" - {group.title}") + if group.description: + print(" " + group.description.strip()) + print() + sys.exit(0) + + # For group search + formatter = parser._get_formatter() + for group in parser._action_groups: + if group.title and group.title.lower() == search_keyword.lower( + ): + formatter.start_section(group.title) + formatter.add_text(group.description) + formatter.add_arguments(group._group_actions) + formatter.end_section() + print(formatter.format_help()) + sys.exit(0) + + # For single arg + matched_actions = [] + + for group in parser._action_groups: + for action in group._group_actions: + # search option name + if any(search_keyword.lower() in opt.lower() + for opt in action.option_strings): + matched_actions.append(action) + + if matched_actions: + print(f"\nParameters matching '{search_keyword}':\n") + formatter = parser._get_formatter() + formatter.add_arguments(matched_actions) + print(formatter.format_help()) + sys.exit(0) + + print(f"\nNo group or parameter matching '{search_keyword}'") + print("Tip: use `--help=listgroup` to view all groups.") + sys.exit(1) + + +def get_max_tokens(max_model_len: int, request: Union[ChatCompletionRequest, + CompletionRequest], + input_length: int, default_sampling_params: dict) -> int: + + max_tokens = getattr(request, "max_completion_tokens", + None) or request.max_tokens + default_max_tokens = max_model_len - input_length + max_output_tokens = current_platform.get_max_output_tokens(input_length) + + return min(val + for val in (default_max_tokens, max_tokens, max_output_tokens, + default_sampling_params.get("max_tokens")) + if val is not None) diff --git a/vllm/env_override.py b/vllm/env_override.py new file mode 100644 index 0000000..9fbdd85 --- /dev/null +++ b/vllm/env_override.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +# set some common config/environment variables that should be set +# for all processes created by vllm and all processes +# that interact with vllm workers. +# they are executed whenever `import vllm` is called. + +if os.environ.get('NCCL_CUMEM_ENABLE', '0') != '0': + logger.warning( + "NCCL_CUMEM_ENABLE is set to %s, skipping override. " + "This may increase memory overhead with cudagraph+allreduce: " + "https://github.com/NVIDIA/nccl/issues/1234", + os.environ['NCCL_CUMEM_ENABLE']) +elif not os.path.exists('/dev/nvidia-caps-imex-channels'): + # NCCL requires NCCL_CUMEM_ENABLE to work with + # multi-node NVLink, typically on GB200-NVL72 systems. + # The ultimate way to detect multi-node NVLink is to use + # NVML APIs, which are too expensive to call here. + # As an approximation, we check the existence of + # /dev/nvidia-caps-imex-channels, used by + # multi-node NVLink to communicate across nodes. + # This will still cost some GPU memory, but it is worthwhile + # because we can get very fast cross-node bandwidth with NVLink. + os.environ['NCCL_CUMEM_ENABLE'] = '0' + +# see https://github.com/vllm-project/vllm/pull/15951 +# it avoids unintentional cuda initialization from torch.cuda.is_available() +os.environ['PYTORCH_NVML_BASED_CUDA_CHECK'] = '1' + +# see https://github.com/vllm-project/vllm/issues/10480 +os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1' +# see https://github.com/vllm-project/vllm/issues/10619 +# torch._inductor.config.compile_threads = 1 diff --git a/vllm/envs.py b/vllm/envs.py new file mode 100644 index 0000000..cf17cfc --- /dev/null +++ b/vllm/envs.py @@ -0,0 +1,1184 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +import os +import sys +import tempfile +from typing import TYPE_CHECKING, Any, Callable, Optional + +if TYPE_CHECKING: + VLLM_HOST_IP: str = "" + VLLM_PORT: Optional[int] = None + VLLM_RPC_BASE_PATH: str = tempfile.gettempdir() + VLLM_USE_MODELSCOPE: bool = False + VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60 + VLLM_NCCL_SO_PATH: Optional[str] = None + LD_LIBRARY_PATH: Optional[str] = None + VLLM_USE_TRITON_FLASH_ATTN: bool = True + VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False + VLLM_FLASH_ATTN_VERSION: Optional[int] = None + LOCAL_RANK: int = 0 + CUDA_VISIBLE_DEVICES: Optional[str] = None + VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60 + VLLM_API_KEY: Optional[str] = None + S3_ACCESS_KEY_ID: Optional[str] = None + S3_SECRET_ACCESS_KEY: Optional[str] = None + S3_ENDPOINT_URL: Optional[str] = None + VLLM_MODEL_REDIRECT_PATH: Optional[str] = None + VLLM_CACHE_ROOT: str = os.path.expanduser("~/.cache/vllm") + VLLM_CONFIG_ROOT: str = os.path.expanduser("~/.config/vllm") + VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai" + VLLM_NO_USAGE_STATS: bool = False + VLLM_DO_NOT_TRACK: bool = False + VLLM_USAGE_SOURCE: str = "" + VLLM_CONFIGURE_LOGGING: int = 1 + VLLM_LOGGING_LEVEL: str = "INFO" + VLLM_LOGGING_PREFIX: str = "" + VLLM_LOGGING_CONFIG_PATH: Optional[str] = None + VLLM_LOGITS_PROCESSOR_THREADS: Optional[int] = None + VLLM_TRACE_FUNCTION: int = 0 + VLLM_ATTENTION_BACKEND: Optional[str] = None + VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None + VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False + VLLM_PP_LAYER_PARTITION: Optional[str] = None + VLLM_CPU_KVCACHE_SPACE: int = 0 + VLLM_CPU_OMP_THREADS_BIND: str = "" + VLLM_CPU_NUM_OF_RESERVED_CPU: int = 0 + VLLM_CPU_MOE_PREPACK: bool = True + VLLM_CPU_SGL_KERNEL: bool = False + VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") + VLLM_XLA_CHECK_RECOMPILATION: bool = False + VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 + VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING: bool = True + VLLM_USE_RAY_SPMD_WORKER: bool = False + VLLM_USE_RAY_COMPILED_DAG: bool = False + VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "auto" + VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False + VLLM_XLA_USE_SPMD: bool = False + VLLM_WORKER_MULTIPROC_METHOD: str = "spawn" + VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") + VLLM_IMAGE_FETCH_TIMEOUT: int = 5 + VLLM_VIDEO_FETCH_TIMEOUT: int = 30 + VLLM_AUDIO_FETCH_TIMEOUT: int = 10 + VLLM_VIDEO_LOADER_BACKEND: str = "opencv" + VLLM_MM_INPUT_CACHE_GIB: int = 8 + VLLM_TARGET_DEVICE: str = "cuda" + MAX_JOBS: Optional[str] = None + NVCC_THREADS: Optional[str] = None + VLLM_USE_PRECOMPILED: bool = False + VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: bool = False + VLLM_NO_DEPRECATION_WARNING: bool = False + VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False + CMAKE_BUILD_TYPE: Optional[str] = None + VERBOSE: bool = False + VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False + VLLM_RPC_TIMEOUT: int = 10000 # ms + VLLM_HTTP_TIMEOUT_KEEP_ALIVE: int = 5 # seconds + VLLM_PLUGINS: Optional[list[str]] = None + VLLM_LORA_RESOLVER_CACHE_DIR: Optional[str] = None + VLLM_TORCH_PROFILER_DIR: Optional[str] = None + VLLM_USE_TRITON_AWQ: bool = False + VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False + VLLM_TREE_DECODING: bool = False + VLLM_SKIP_P2P_CHECK: bool = False + VLLM_DISABLED_KERNELS: list[str] = [] + VLLM_USE_V1: bool = True + VLLM_ROCM_USE_AITER: bool = False + VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False + VLLM_ROCM_USE_AITER_LINEAR: bool = True + VLLM_ROCM_USE_AITER_MOE: bool = True + VLLM_ROCM_USE_AITER_RMSNORM: bool = True + VLLM_ROCM_USE_AITER_MLA: bool = True + VLLM_ROCM_USE_AITER_MHA: bool = True + VLLM_ROCM_USE_SKINNY_GEMM: bool = True + VLLM_ROCM_FP8_PADDING: bool = True + VLLM_ROCM_MOE_PADDING: bool = True + VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True + VLLM_QUARK_EMU_MEM_OPT: bool = False + VLLM_ENABLE_V1_MULTIPROCESSING: bool = True + VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 + VLLM_DISABLE_COMPILE_CACHE: bool = False + Q_SCALE_CONSTANT: int = 200 + K_SCALE_CONSTANT: int = 200 + V_SCALE_CONSTANT: int = 100 + VLLM_SERVER_DEV_MODE: bool = False + VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128 + VLLM_MLA_DISABLE: bool = False + VLLM_RAY_PER_WORKER_GPUS: float = 1.0 + VLLM_RAY_BUNDLE_INDICES: str = "" + VLLM_CUDART_SO_PATH: Optional[str] = None + VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True + VLLM_HPU_USE_DELAYED_SAMPLING: bool = False + VLLM_DP_RANK: int = 0 + VLLM_DP_RANK_LOCAL: int = -1 + VLLM_DP_SIZE: int = 1 + VLLM_DP_MASTER_IP: str = "" + VLLM_DP_MASTER_PORT: int = 0 + VLLM_MOE_DP_CHUNK_SIZE: int = 256 + VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False + VLLM_MARLIN_USE_ATOMIC_ADD: bool = False + VLLM_V0_USE_OUTLINES_CACHE: bool = False + VLLM_TPU_BUCKET_PADDING_GAP: int = 0 + VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None + VLLM_USE_DEEP_GEMM: bool = False + VLLM_XGRAMMAR_CACHE_MB: int = 0 + VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 + VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False + VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" + VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557 + VLLM_ALL2ALL_BACKEND: str = "naive" + VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840 + VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1 + VLLM_SLEEP_WHEN_IDLE: bool = False + VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 + VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300 + VLLM_KV_CACHE_LAYOUT: Optional[str] = None + VLLM_COMPUTE_NANS_IN_LOGITS: bool = False + VLLM_USE_NVFP4_CT_EMULATIONS: bool = False + VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE" + VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True + VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None + + # add envs + VLLM_OPTEST_URLS_PORT: Optional[int] = None + VLLM_OPTEST_MODELS_PATH: str = "" + VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False + VLLM_USE_TRITON_OPT_MLA: bool = False + VLLM_USE_FLASH_MLA: bool = False + VLLM_USE_OPT_OP: bool = False + VLLM_USE_TC_PAGED_ATTN: bool = False + VLLM_USE_PA_PRINT_PARAM: bool = False + VLLM_SPEC_DECODE_EAGER: bool = False + VLLM_PCIE_USE_CUSTOM_ALLREDUCE: bool = False + VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX: int = 16 + VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None + VLLM_HAS_CONTEXT_DEFAULT: bool = False + VLLM_USE_NN: bool = False + VLLM_ENABLE_TBO: bool = False + VLLM_TBO_REQ_DELAY_MS: int = 0 + VLLM_TBO_DECODE_BS: int = 0 + VLLM_TBO_MIN_TOKENS: int = 200 + VLLM_ZERO_OVERHEAD: bool = False + VLLM_ENABLE_MOE_FUSED_GATE: bool = False + VLLM_USE_FLASH_ATTN_PA: bool = False + VLLM_USE_APEX_RN: bool = False + VLLM_USE_GLOBAL_CACHE13: bool = False + VLLM_USE_LIGHT_OP: bool = False + VLLM_USE_TRITON_CAT: bool = False + USE_FUSED_RMS_QUANT: bool = False + VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False + +def get_default_cache_root(): + return os.getenv( + "XDG_CACHE_HOME", + os.path.join(os.path.expanduser("~"), ".cache"), + ) + + +def get_default_config_root(): + return os.getenv( + "XDG_CONFIG_HOME", + os.path.join(os.path.expanduser("~"), ".config"), + ) + + +def maybe_convert_int(value: Optional[str]) -> Optional[int]: + if value is None: + return None + return int(value) + + +def get_vllm_port() -> Optional[int]: + """Get the port from VLLM_PORT environment variable. + + Returns: + The port number as an integer if VLLM_PORT is set, None otherwise. + + Raises: + ValueError: If VLLM_PORT is a URI, suggest k8s service discovery issue. + """ + if 'VLLM_PORT' not in os.environ: + return None + + port = os.getenv('VLLM_PORT', '0') + + try: + return int(port) + except ValueError as err: + from urllib.parse import urlparse + parsed = urlparse(port) + if parsed.scheme: + raise ValueError( + f"VLLM_PORT '{port}' appears to be a URI. " + "This may be caused by a Kubernetes service discovery issue," + "check the warning in: https://docs.vllm.ai/en/stable/serving/env_vars.html" + ) from None + raise ValueError( + f"VLLM_PORT '{port}' must be a valid integer") from err + + +# The begin-* and end* here are used by the documentation generator +# to extract the used env vars. + +# --8<-- [start:env-vars-definition] + +environment_variables: dict[str, Callable[[], Any]] = { + + # ================== Installation Time Env Vars ================== + + # Target device of vLLM, supporting [cuda (by default), + # rocm, neuron, cpu] + "VLLM_TARGET_DEVICE": + lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda"), + + # Maximum number of compilation jobs to run in parallel. + # By default this is the number of CPUs + "MAX_JOBS": + lambda: os.getenv("MAX_JOBS", None), + + # Number of threads to use for nvcc + # By default this is 1. + # If set, `MAX_JOBS` will be reduced to avoid oversubscribing the CPU. + "NVCC_THREADS": + lambda: os.getenv("NVCC_THREADS", None), + + # If set, vllm will use precompiled binaries (*.so) + "VLLM_USE_PRECOMPILED": + lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")) or bool( + os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")), + + # Whether to force using nightly wheel in python build. + # This is used for testing the nightly wheel in python build. + "VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL": + lambda: bool(int(os.getenv("VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL", "0")) + ), + + # CMake build type + # If not set, defaults to "Debug" or "RelWithDebInfo" + # Available options: "Debug", "Release", "RelWithDebInfo" + "CMAKE_BUILD_TYPE": + lambda: os.getenv("CMAKE_BUILD_TYPE"), + + # If set, vllm will print verbose logs during installation + "VERBOSE": + lambda: bool(int(os.getenv('VERBOSE', '0'))), + + # Root directory for vLLM configuration files + # Defaults to `~/.config/vllm` unless `XDG_CONFIG_HOME` is set + # Note that this not only affects how vllm finds its configuration files + # during runtime, but also affects how vllm installs its configuration + # files during **installation**. + "VLLM_CONFIG_ROOT": + lambda: os.path.expanduser( + os.getenv( + "VLLM_CONFIG_ROOT", + os.path.join(get_default_config_root(), "vllm"), + )), + + # ================== Runtime Env Vars ================== + + # Root directory for vLLM cache files + # Defaults to `~/.cache/vllm` unless `XDG_CACHE_HOME` is set + "VLLM_CACHE_ROOT": + lambda: os.path.expanduser( + os.getenv( + "VLLM_CACHE_ROOT", + os.path.join(get_default_cache_root(), "vllm"), + )), + + # used in distributed environment to determine the ip address + # of the current node, when the node has multiple network interfaces. + # If you are using multi-node inference, you should set this differently + # on each node. + 'VLLM_HOST_IP': + lambda: os.getenv('VLLM_HOST_IP', ""), + + # used in distributed environment to manually set the communication port + # Note: if VLLM_PORT is set, and some code asks for multiple ports, the + # VLLM_PORT will be used as the first port, and the rest will be generated + # by incrementing the VLLM_PORT value. + 'VLLM_PORT': + get_vllm_port, + + # path used for ipc when the frontend api server is running in + # multi-processing mode to communicate with the backend engine process. + 'VLLM_RPC_BASE_PATH': + lambda: os.getenv('VLLM_RPC_BASE_PATH', tempfile.gettempdir()), + + # If true, will load models from ModelScope instead of Hugging Face Hub. + # note that the value is true or false, not numbers + "VLLM_USE_MODELSCOPE": + lambda: os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true", + + # Interval in seconds to log a warning message when the ring buffer is full + "VLLM_RINGBUFFER_WARNING_INTERVAL": + lambda: int(os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60")), + + # path to cudatoolkit home directory, under which should be bin, include, + # and lib directories. + "CUDA_HOME": + lambda: os.environ.get("CUDA_HOME", None), + + # Path to the NCCL library file. It is needed because nccl>=2.19 brought + # by PyTorch contains a bug: https://github.com/NVIDIA/nccl/issues/1234 + "VLLM_NCCL_SO_PATH": + lambda: os.environ.get("VLLM_NCCL_SO_PATH", None), + + # when `VLLM_NCCL_SO_PATH` is not set, vllm will try to find the nccl + # library file in the locations specified by `LD_LIBRARY_PATH` + "LD_LIBRARY_PATH": + lambda: os.environ.get("LD_LIBRARY_PATH", None), + + # flag to control if vllm should use triton flash attention + "VLLM_USE_TRITON_FLASH_ATTN": + lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "False").lower() in + ("true", "1")), + + # Use separate prefill and decode kernels for V1 attention instead of + # the unified triton kernel. + "VLLM_V1_USE_PREFILL_DECODE_ATTENTION": + lambda: + (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in + ("true", "1")), + + # Force vllm to use a specific flash-attention version (2 or 3), only valid + # when using the flash-attention backend. + "VLLM_FLASH_ATTN_VERSION": + lambda: maybe_convert_int(os.environ.get("VLLM_FLASH_ATTN_VERSION", None)), + + # Internal flag to enable Dynamo fullgraph capture + "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE": + lambda: bool( + os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), + + # Feature flag to enable/disable Inductor standalone compile. + # In torch <= 2.7 we ignore this flag; in torch >= 2.8 this is + # enabled by default. + "VLLM_USE_STANDALONE_COMPILE": + lambda: os.environ.get("VLLM_USE_STANDALONE_COMPILE", "1") == "1", + + # local rank of the process in the distributed setting, used to determine + # the GPU device id + "LOCAL_RANK": + lambda: int(os.environ.get("LOCAL_RANK", "0")), + + # used to control the visible devices in the distributed setting + "CUDA_VISIBLE_DEVICES": + lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None), + + # timeout for each iteration in the engine + "VLLM_ENGINE_ITERATION_TIMEOUT_S": + lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "120")), + + # API key for vLLM API server + "VLLM_API_KEY": + lambda: os.environ.get("VLLM_API_KEY", None), + + # Whether to log responses from API Server for debugging + "VLLM_DEBUG_LOG_API_SERVER_RESPONSE": + lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False" + ).lower() == "true", + + # S3 access information, used for tensorizer to load model from S3 + "S3_ACCESS_KEY_ID": + lambda: os.environ.get("S3_ACCESS_KEY_ID", None), + "S3_SECRET_ACCESS_KEY": + lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None), + "S3_ENDPOINT_URL": + lambda: os.environ.get("S3_ENDPOINT_URL", None), + + # Usage stats collection + "VLLM_USAGE_STATS_SERVER": + lambda: os.environ.get("VLLM_USAGE_STATS_SERVER", "https://stats.vllm.ai"), + "VLLM_NO_USAGE_STATS": + lambda: os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1", + "VLLM_DO_NOT_TRACK": + lambda: (os.environ.get("VLLM_DO_NOT_TRACK", None) or os.environ.get( + "DO_NOT_TRACK", None) or "0") == "1", + "VLLM_USAGE_SOURCE": + lambda: os.environ.get("VLLM_USAGE_SOURCE", "production"), + + # Logging configuration + # If set to 0, vllm will not configure logging + # If set to 1, vllm will configure logging using the default configuration + # or the configuration file specified by VLLM_LOGGING_CONFIG_PATH + "VLLM_CONFIGURE_LOGGING": + lambda: int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")), + "VLLM_LOGGING_CONFIG_PATH": + lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"), + + # this is used for configuring the default logging level + "VLLM_LOGGING_LEVEL": + lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper(), + + # if set, VLLM_LOGGING_PREFIX will be prepended to all log messages + "VLLM_LOGGING_PREFIX": + lambda: os.getenv("VLLM_LOGGING_PREFIX", ""), + + # if set, vllm will call logits processors in a thread pool with this many + # threads. This is useful when using custom logits processors that either + # (a) launch additional CUDA kernels or (b) do significant CPU-bound work + # while not holding the python GIL, or both. + "VLLM_LOGITS_PROCESSOR_THREADS": + lambda: int(os.getenv("VLLM_LOGITS_PROCESSOR_THREADS", "0")) + if "VLLM_LOGITS_PROCESSOR_THREADS" in os.environ else None, + + # Trace function calls + # If set to 1, vllm will trace function calls + # Useful for debugging + "VLLM_TRACE_FUNCTION": + lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")), + + # Backend for attention computation + # Available options: + # - "TORCH_SDPA": use torch.nn.MultiheadAttention + # - "FLASH_ATTN": use FlashAttention + # - "XFORMERS": use XFormers + # - "ROCM_FLASH": use ROCmFlashAttention + # - "FLASHINFER": use flashinfer + # - "FLASHMLA": use FlashMLA + "VLLM_ATTENTION_BACKEND": + lambda: os.getenv("VLLM_ATTENTION_BACKEND", None), + + # If set, vllm will use flashinfer sampler + "VLLM_USE_FLASHINFER_SAMPLER": + lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"])) + if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None, + + # If set, vllm will force flashinfer to use tensor cores; + # otherwise will use heuristic based on model architecture. + "VLLM_FLASHINFER_FORCE_TENSOR_CORES": + lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))), + + # Pipeline stage partition strategy + "VLLM_PP_LAYER_PARTITION": + lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), + + # (CPU backend only) CPU key-value cache space. + # default is 4 GiB + "VLLM_CPU_KVCACHE_SPACE": + lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")), + + # (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31", + # "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'. + "VLLM_CPU_OMP_THREADS_BIND": + lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "auto"), + + # (CPU backend only) CPU cores not used by OMP threads . + # Those CPU cores will not be used by OMP threads of a rank. + "VLLM_CPU_NUM_OF_RESERVED_CPU": + lambda: int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0")), + + # (CPU backend only) whether to use prepack for MoE layer. This will be + # passed to ipex.llm.modules.GatedMLPMOE. On unsupported CPUs, you might + # need to set this to "0" (False). + "VLLM_CPU_MOE_PREPACK": + lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))), + + # (CPU backend only) whether to use SGL kernels, optimized for small batch. + "VLLM_CPU_SGL_KERNEL": + lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))), + + # If the env var is set, then all workers will execute as separate + # processes from the engine, and we use the same mechanism to trigger + # execution on all workers. + # Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it. + "VLLM_USE_RAY_SPMD_WORKER": + lambda: bool(int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0"))), + + # If the env var is set, it uses the Ray's Compiled Graph + # (previously known as ADAG) API which optimizes the + # control plane overhead. + # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. + # Note that this variable is set to 1 in V1 by default + # when ray distributed executor is used. + "VLLM_USE_RAY_COMPILED_DAG": + lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))), + + # If the env var is set, Ray Compiled Graph uses the specified + # channel type to communicate between workers belonging to + # different pipeline-parallel stages. + # Available options: + # - "auto": use the default channel type + # - "nccl": use NCCL for communication + # - "shm": use shared memory and gRPC for communication + # This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set. + "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE": + lambda: os.getenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "auto"), + + # If the env var is set, it enables GPU communication overlap + # (experimental feature) in Ray's Compiled Graph. This flag is ignored if + # VLLM_USE_RAY_COMPILED_DAG is not set. + "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM": + lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "0")) + ), + + # Use dedicated multiprocess context for workers. + # Both spawn and fork work + "VLLM_WORKER_MULTIPROC_METHOD": + lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"), + + # Path to the cache for storing downloaded assets + "VLLM_ASSETS_CACHE": + lambda: os.path.expanduser( + os.getenv( + "VLLM_ASSETS_CACHE", + os.path.join(get_default_cache_root(), "vllm", "assets"), + )), + + # Timeout for fetching images when serving multimodal models + # Default is 5 seconds + "VLLM_IMAGE_FETCH_TIMEOUT": + lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")), + + # Timeout for fetching videos when serving multimodal models + # Default is 30 seconds + "VLLM_VIDEO_FETCH_TIMEOUT": + lambda: int(os.getenv("VLLM_VIDEO_FETCH_TIMEOUT", "30")), + + # Timeout for fetching audio when serving multimodal models + # Default is 10 seconds + "VLLM_AUDIO_FETCH_TIMEOUT": + lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")), + + # Backend for Video IO + # - "opencv": Default backend that uses OpenCV stream buffered backend. + # + # Custom backend implementations can be registered + # via `@VIDEO_LOADER_REGISTRY.register("my_custom_video_loader")` and + # imported at runtime. + # If a non-existing backend is used, an AssertionError will be thrown. + "VLLM_VIDEO_LOADER_BACKEND": + lambda: os.getenv("VLLM_VIDEO_LOADER_BACKEND", "opencv"), + + # Cache size (in GiB) for multimodal input cache + # Default is 4 GiB + "VLLM_MM_INPUT_CACHE_GIB": + lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")), + + # Path to the XLA persistent cache directory. + # Only used for XLA devices such as TPUs. + "VLLM_XLA_CACHE_PATH": + lambda: os.path.expanduser( + os.getenv( + "VLLM_XLA_CACHE_PATH", + os.path.join(get_default_cache_root(), "vllm", "xla_cache"), + )), + + # If set, assert on XLA recompilation after each execution step. + "VLLM_XLA_CHECK_RECOMPILATION": + lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0"))), + + # Enable SPMD mode for TPU backend. + "VLLM_XLA_USE_SPMD": + lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))), + "VLLM_FUSED_MOE_CHUNK_SIZE": + lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")), + # Control whether to use fused MoE activation chunking. Current chunking + # logic is incompatible with torch.compile and causes IMA. See issue + # https://github.com/vllm-project/vllm/issues/19631. + "VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING": + lambda: bool( + int(os.getenv("VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING", "1"))), + + # If set, vllm will skip the deprecation warnings. + "VLLM_NO_DEPRECATION_WARNING": + lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))), + + # If set, the OpenAI API server will stay alive even after the underlying + # AsyncLLMEngine errors and stops serving requests + "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": + lambda: bool(os.getenv("VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", 0)), + + # If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows + # the user to specify a max sequence length greater than + # the max length derived from the model's config.json. + # To enable this, set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1. + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": + lambda: + (os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in + ("1", "true")), + + # If set, forces FP8 Marlin to be used for FP8 quantization regardless + # of the hardware support for FP8 compute. + "VLLM_TEST_FORCE_FP8_MARLIN": + lambda: + (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in + ("1", "true")), + "VLLM_TEST_FORCE_LOAD_FORMAT": + lambda: os.getenv("VLLM_TEST_FORCE_LOAD_FORMAT", "dummy"), + + # Time in ms for the zmq client to wait for a response from the backend + # server for simple data operations + "VLLM_RPC_TIMEOUT": + lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")), + + # Timeout in seconds for keeping HTTP connections alive in API server + "VLLM_HTTP_TIMEOUT_KEEP_ALIVE": + lambda: int(os.environ.get("VLLM_HTTP_TIMEOUT_KEEP_ALIVE", "5")), + + # a list of plugin names to load, separated by commas. + # if this is not set, it means all plugins will be loaded + # if this is set to an empty string, no plugins will be loaded + "VLLM_PLUGINS": + lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[ + "VLLM_PLUGINS"].split(","), + + # a local directory to look in for unrecognized LoRA adapters. + # only works if plugins are enabled and + # VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled. + "VLLM_LORA_RESOLVER_CACHE_DIR": + lambda: os.getenv("VLLM_LORA_RESOLVER_CACHE_DIR", None), + + # Enables torch profiler if set. Path to the directory where torch profiler + # traces are saved. Note that it must be an absolute path. + "VLLM_TORCH_PROFILER_DIR": + lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os + .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))), + + # If set, vLLM will use Triton implementations of AWQ. + "VLLM_USE_TRITON_AWQ": + lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))), + + # If set, allow loading or unloading lora adapters in runtime, + "VLLM_ALLOW_RUNTIME_LORA_UPDATING": + lambda: + (os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in + ("1", "true")), + + # If set, vLLM will use tree-style speculative decoding. + "VLLM_TREE_DECODING": + lambda: + (os.environ.get("VLLM_TREE_DECODING", "0").strip().lower() in + ("1", "true")), + # By default, vLLM will check the peer-to-peer capability itself, + # in case of broken drivers. See https://github.com/vllm-project/vllm/blob/a9b15c606fea67a072416ea0ea115261a2756058/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L101-L108 for details. # noqa + # If this env var is set to 1, vLLM will skip the peer-to-peer check, + # and trust the driver's peer-to-peer capability report. + "VLLM_SKIP_P2P_CHECK": + lambda: os.getenv("VLLM_SKIP_P2P_CHECK", "0") == "1", + + # List of quantization kernels that should be disabled, used for testing + # and performance comparisons. Currently only affects MPLinearKernel + # selection + # (kernels: MacheteLinearKernel, MarlinLinearKernel, ExllamaLinearKernel) + "VLLM_DISABLED_KERNELS": + lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[ + "VLLM_DISABLED_KERNELS"].split(","), + + # If set, use the V1 code path. + "VLLM_USE_V1": + lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))), + + # Disable aiter ops unless specifically enabled. + # Acts as a parent switch to enable the rest of the other operations. + "VLLM_ROCM_USE_AITER": + lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in + ("true", "1")), + + # Whether to use aiter paged attention. + # By default is disabled. + "VLLM_ROCM_USE_AITER_PAGED_ATTN": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "False").lower() in + ("true", "1")), + + # use aiter linear op if aiter ops are enabled + # The following list of related ops + # - scaled_mm (per-tensor / rowwise) + "VLLM_ROCM_USE_AITER_LINEAR": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in + ("true", "1")), + + # Whether to use aiter moe ops. + # By default is enabled. + "VLLM_ROCM_USE_AITER_MOE": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in + ("true", "1")), + + # use aiter rms norm op if aiter ops are enabled. + "VLLM_ROCM_USE_AITER_RMSNORM": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in + ("true", "1")), + + # Whether to use aiter mla ops. + # By default is enabled. + "VLLM_ROCM_USE_AITER_MLA": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in + ("true", "1")), + + # Whether to use aiter mha ops. + # By default is enabled. + "VLLM_ROCM_USE_AITER_MHA": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in + ("true", "1")), + + # use rocm skinny gemms + "VLLM_ROCM_USE_SKINNY_GEMM": + lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in + ("true", "1")), + + # Pad the fp8 weights to 256 bytes for ROCm + "VLLM_ROCM_FP8_PADDING": + lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), + + # Pad the weights for the moe kernel + "VLLM_ROCM_MOE_PADDING": + lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "0"))), + + # custom paged attention kernel for MI3* cards + "VLLM_ROCM_CUSTOM_PAGED_ATTN": + lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in + ("true", "1")), + + # Custom quick allreduce kernel for MI3* cards + # Choice of quantization level: FP, INT8, INT6, INT4 or NONE + # Recommended for large models to get allreduce + "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION": + lambda: os.getenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", "NONE").upper(), + + # Custom quick allreduce kernel for MI3* cards + # Due to the lack of the bfloat16 asm instruction, bfloat16 + # kernels are slower than fp16, + # If environment variable is set to 1, the input is converted to fp16 + "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16": + lambda: + (os.getenv("VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", "True").lower() in + ("true", "1")), + + # Custom quick allreduce kernel for MI3* cards. + # Controls the maximum allowed number of data bytes(MB) for custom quick + # allreduce communication. + # Default: 2048 MB. + # Data exceeding this size will use either custom allreduce or RCCL + # communication. + "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB": + lambda: maybe_convert_int( + os.environ.get("VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", None)), + + # If set, when running in Quark emulation mode, do not dequantize the + # weights at load time. Instead, dequantize weights on-the-fly during + # kernel execution. + # This allows running larger models at the cost of slower inference. + # This flag has no effect when not running in Quark emulation mode. + "VLLM_QUARK_EMU_MEM_OPT": + lambda: bool(int(os.getenv("VLLM_QUARK_EMU_MEM_OPT", "0"))), + + # Divisor for dynamic query scale factor calculation for FP8 KV Cache + "Q_SCALE_CONSTANT": + lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")), + # Divisor for dynamic key scale factor calculation for FP8 KV Cache + "K_SCALE_CONSTANT": + lambda: int(os.getenv("K_SCALE_CONSTANT", "200")), + # Divisor for dynamic value scale factor calculation for FP8 KV Cache + "V_SCALE_CONSTANT": + lambda: int(os.getenv("V_SCALE_CONSTANT", "100")), + + # If set, enable multiprocessing in LLM for the V1 code path. + "VLLM_ENABLE_V1_MULTIPROCESSING": + lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))), + "VLLM_LOG_BATCHSIZE_INTERVAL": + lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")), + "VLLM_DISABLE_COMPILE_CACHE": + lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))), + + # If set, vllm will run in development mode, which will enable + # some additional endpoints for developing and debugging, + # e.g. `/reset_prefix_cache` + "VLLM_SERVER_DEV_MODE": + lambda: bool(int(os.getenv("VLLM_SERVER_DEV_MODE", "0"))), + + # Controls the maximum number of requests to handle in a + # single asyncio task when processing per-token outputs in the + # V1 AsyncLLM interface. It is applicable when handling a high + # concurrency of streaming requests. + # Setting this too high can result in a higher variance of + # inter-message latencies. Setting it too low can negatively impact + # TTFT and overall throughput. + "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE": + lambda: int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128")), + + # If set, vLLM will disable the MLA attention optimizations. + "VLLM_MLA_DISABLE": + lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))), + + # Number of GPUs per worker in Ray, if it is set to be a fraction, + # it allows ray to schedule multiple actors on a single GPU, + # so that users can colocate other actors on the same GPUs as vLLM. + "VLLM_RAY_PER_WORKER_GPUS": + lambda: float(os.getenv("VLLM_RAY_PER_WORKER_GPUS", "1.0")), + + # Bundle indices for Ray, if it is set, it can control precisely + # which indices are used for the Ray bundle, for every worker. + # Format: comma-separated list of integers, e.g. "0,1,2,3" + "VLLM_RAY_BUNDLE_INDICES": + lambda: os.getenv("VLLM_RAY_BUNDLE_INDICES", ""), + + # In some system, find_loaded_library() may not work. So we allow users to + # specify the path through environment variable VLLM_CUDART_SO_PATH. + "VLLM_CUDART_SO_PATH": + lambda: os.getenv("VLLM_CUDART_SO_PATH", None), + + # Contiguous cache fetching to avoid using costly gather operation on + # Gaudi3. This is only applicable to HPU contiguous cache. If set to true, + # contiguous cache fetch will be used. + "VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH": + lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in + ("1", "true"), + + # Use delayed sampling for HPU to reduce host cpu overhead + # between each step. + "VLLM_HPU_USE_DELAYED_SAMPLING": + lambda: os.environ.get("VLLM_DELAYED_SAMPLING", "false").lower() in + ("1", "true"), + + # Rank of the process in the data parallel setting + "VLLM_DP_RANK": + lambda: int(os.getenv("VLLM_DP_RANK", "0")), + + # Rank of the process in the data parallel setting. + # Defaults to VLLM_DP_RANK when not set. + "VLLM_DP_RANK_LOCAL": + lambda: int( + os.getenv("VLLM_DP_RANK_LOCAL", sys.modules[__name__].VLLM_DP_RANK)), + + # World size of the data parallel setting + "VLLM_DP_SIZE": + lambda: int(os.getenv("VLLM_DP_SIZE", "1")), + + # IP address of the master node in the data parallel setting + "VLLM_DP_MASTER_IP": + lambda: os.getenv("VLLM_DP_MASTER_IP", "127.0.0.1"), + + # Port of the master node in the data parallel setting + "VLLM_DP_MASTER_PORT": + lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")), + + # In the context of executing MoE models with Data-Parallel, Expert-Parallel + # and Batched All-to-All dispatch/combine kernels, VLLM_MOE_DP_CHUNK_SIZE + # dictates the quantum of tokens that can be dispatched from a DP + # rank. All DP ranks process the activations in VLLM_MOE_DP_CHUNK_SIZE + # units. + "VLLM_MOE_DP_CHUNK_SIZE": + lambda: int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256")), + + # Randomize inputs during dummy runs when using Data Parallel + "VLLM_RANDOMIZE_DP_DUMMY_INPUTS": + lambda: os.environ.get("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0") == "1", + + # Whether to use S3 path for model loading in CI via RunAI Streamer + "VLLM_CI_USE_S3": + lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1", + + # Use model_redirect to redirect the model name to a local folder. + # `model_redirect` can be a json file mapping the model between + # repo_id and local folder: + # {"meta-llama/Llama-3.2-1B": "/tmp/Llama-3.2-1B"} + # or a space separated values table file: + # meta-llama/Llama-3.2-1B /tmp/Llama-3.2-1B + "VLLM_MODEL_REDIRECT_PATH": + lambda: os.environ.get("VLLM_MODEL_REDIRECT_PATH", None), + + # Whether to use atomicAdd reduce in gptq/awq marlin kernel. + "VLLM_MARLIN_USE_ATOMIC_ADD": + lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1", + + # Whether to turn on the outlines cache for V0 + # This cache is unbounded and on disk, so it's not safe to use in + # an environment with potentially malicious users. + "VLLM_V0_USE_OUTLINES_CACHE": + lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1", + + # Gap between padding buckets for the forward pass. So we have + # 8, we will run forward pass with [16, 24, 32, ...]. + "VLLM_TPU_BUCKET_PADDING_GAP": + lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"]) + if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0, + "VLLM_TPU_MOST_MODEL_LEN": + lambda: maybe_convert_int(os.environ.get("VLLM_TPU_MOST_MODEL_LEN", None)), + + # Allow use of DeepGemm kernels for fused moe ops. + "VLLM_USE_DEEP_GEMM": + lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), + + # Control the cache sized used by the xgrammar compiler. The default + # of 512 MB should be enough for roughly 1000 JSON schemas. + # It can be changed with this variable if needed for some reason. + "VLLM_XGRAMMAR_CACHE_MB": + lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")), + + # Control the threshold for msgspec to use 'zero copy' for + # serialization/deserialization of tensors. Tensors below + # this limit will be encoded into the msgpack buffer, and + # tensors above will instead be sent via a separate message. + # While the sending side still actually copies the tensor + # in all cases, on the receiving side, tensors above this + # limit will actually be zero-copy decoded. + "VLLM_MSGPACK_ZERO_COPY_THRESHOLD": + lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")), + + # If set, allow insecure serialization using pickle. + # This is useful for environments where it is deemed safe to use the + # insecure method and it is needed for some reason. + "VLLM_ALLOW_INSECURE_SERIALIZATION": + lambda: bool(int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0"))), + + # IP address used for NIXL handshake between remote agents. + "VLLM_NIXL_SIDE_CHANNEL_HOST": + lambda: os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST", "localhost"), + + # Port used for NIXL handshake between remote agents. + "VLLM_NIXL_SIDE_CHANNEL_PORT": + lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")), + + # all2all backend for vllm's expert parallel communication + # Available options: + # - "naive": naive all2all implementation using all-reduce + # - "pplx": use pplx kernels + # - "deepep_high_throughput", use deepep high-throughput kernels + # - "deepep_low_latency", use deepep low-latency kernels + "VLLM_ALL2ALL_BACKEND": + lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"), + + # Control the maximum number of tokens per expert supported by the + # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for + # the blockscale tensor of activations NVFP4 Quantization. + # This is used to prevent the kernel from running out of memory. + "VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE": + lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")), + + # Regex timeout for use by the vLLM tool parsing plugins. + "VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS": + lambda: int(os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1")), + + # Reduce CPU usage when vLLM is idle. Enabling this will incur small + # latency penalty when a request eventually comes. + "VLLM_SLEEP_WHEN_IDLE": + lambda: bool(int(os.getenv("VLLM_SLEEP_WHEN_IDLE", "0"))), + + # Control the max chunk bytes (in MB) for the rpc message queue. + # Object larger than this threshold will be broadcast to worker + # processes via zmq. + "VLLM_MQ_MAX_CHUNK_BYTES_MB": + lambda: int(os.getenv("VLLM_MQ_MAX_CHUNK_BYTES_MB", "16")), + + # Timeout in seconds for execute_model RPC calls in multiprocessing + # executor (only applies when TP > 1). + "VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS": + lambda: int(os.getenv("VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS", "300")), + + # KV Cache layout used throughout vllm. + # Some common values are: + # - NHD + # - HND + # Where N=num_blocks, H=num_heads and D=head_size. The default value will + # leave the layout choice to the backend. Mind that backends may only + # implement and support a subset of all possible layouts. + "VLLM_KV_CACHE_LAYOUT": + lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None), + + # Enable checking whether the generated logits contain NaNs, + # indicating corrupted output. Useful for debugging low level bugs + # or bad hardware but it may add compute overhead. + "VLLM_COMPUTE_NANS_IN_LOGITS": + lambda: bool(int(os.getenv("VLLM_COMPUTE_NANS_IN_LOGITS", "0"))), + + # Controls whether or not emulations are used for NVFP4 + # generations on machines < 100 for compressed-tensors + # models + "VLLM_USE_NVFP4_CT_EMULATIONS": + lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))), + + # used in optest environment to manually set the https port + 'VLLM_OPTEST_URLS_PORT': + lambda: int(os.getenv('VLLM_OPTEST_URLS_PORT', '8000')) + if 'VLLM_OPTEST_URLS_PORT' in os.environ else None, + + # Path to the optest models. + # If set, will load models from local path instead of Hugging Face Hub. + 'VLLM_OPTEST_MODELS_PATH': + lambda: os.getenv('VLLM_OPTEST_MODELS_PATH', "") or os.getenv("OPTEST_MODELS_PATH", ""), + + # flag to control if vllm should use triton prefix flash attention + "VLLM_USE_TRITON_PREFIX_FLASH_ATTN": + lambda: (os.environ.get("VLLM_USE_TRITON_PREFIX_FLASH_ATTN", "False").lower() in + ("true", "1")), + + # If set, vLLM will use optimized MLA attention optimizations. + "VLLM_USE_TRITON_OPT_MLA": + lambda: bool(int(os.getenv("VLLM_USE_TRITON_OPT_MLA", "0"))), + + # If set, vLLM will use FLASH MLA attention optimizations. + "VLLM_USE_FLASH_MLA": + lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))), + + # flag to control vllm to use optimized kernels + "VLLM_USE_OPT_OP": + lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in + ("true", "1")), + + # flag to control vllm to use optimized tc paged attn kernels + "VLLM_USE_TC_PAGED_ATTN": + lambda: (os.environ.get("VLLM_USE_TC_PAGED_ATTN", "True").lower() in + ("true", "1")), + + # flag to control if vllm print pa parameters + "VLLM_USE_PA_PRINT_PARAM": + lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in + ("true", "1")), + + # If set, vLLM will disable the draft model in cudagraph mode. + "VLLM_SPEC_DECODE_EAGER": + lambda: bool(int(os.getenv("VLLM_SPEC_DECODE_EAGER", "0"))), + + # flag to control vllm to use optimized kernels + "VLLM_PCIE_USE_CUSTOM_ALLREDUCE": + lambda: bool(int(os.environ.get("VLLM_PCIE_USE_CUSTOM_ALLREDUCE", "0"))), + + # flag to control vllm to use optimized kernels + "VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX": + lambda: int(os.getenv("VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX", "16")), + + # If set, vLLM will disable the draft model in cudagraph mode. + "VLLM_ENFORCE_EAGER_BS_THRESHOLD": + lambda: int(os.environ.get("VLLM_ENFORCE_EAGER_BS_THRESHOLD", "-1")), + + # 'has_comtext' is a variable in common.py, which is calculated + # by metadata by default. However, it may introduce synchronization + # and affect performance, so it is directly assigned as False. + # If there are any problems during use, use environment variables + # to restore the default usage. + "VLLM_HAS_CONTEXT_DEFAULT": + lambda: bool(int(os.getenv("VLLM_HAS_CONTEXT_DEFAULT", "1"))), + + # If set, vLLM will transpose weight to use nn layout + "VLLM_USE_NN": + lambda: (os.environ.get("VLLM_USE_NN", "True").lower() in + ("true", "1")), + + # Enable two batch overlap. + "VLLM_ENABLE_TBO": + lambda: bool(int(os.getenv("VLLM_ENABLE_TBO", "0"))), + + # set delay on server when only one requet, the purpose is to merge a larger batch. + "VLLM_TBO_REQ_DELAY_MS": + lambda: int(os.getenv("VLLM_TBO_REQ_DELAY_MS", "0")), + + # set the minimum batch size to enable TBO in decode, if < 2 , disable TBO in decode. + "VLLM_TBO_DECODE_BS": + lambda: int(os.getenv("VLLM_TBO_DECODE_BS", "0")), + + # set the minimum tokens size for each mini-batch to enable TBO on v1, default is 200. + "VLLM_TBO_MIN_TOKENS": + lambda: int(os.getenv("VLLM_TBO_MIN_TOKENS", "200")), + + # Enable zero overhead scheduler. + "VLLM_ZERO_OVERHEAD": + lambda: bool(int(os.getenv("VLLM_ZERO_OVERHEAD", "0"))), + + # If set, vLLM will enable the moe_fused_gate kernel. + "VLLM_ENABLE_MOE_FUSED_GATE": + lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_FUSED_GATE", "1"))), + + # vLLM will use FlashAttention Backend for page attention computation on rocm + "VLLM_USE_FLASH_ATTN_PA": + lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "True").lower() in + ("true", "1")), + + # vLLM will use apex for rmsnorm + "VLLM_USE_APEX_RN": + lambda: (os.environ.get("VLLM_USE_APEX_RN", "False").lower() in + ("true", "1")), + # vLLM will use global cache for moe + "VLLM_USE_GLOBAL_CACHE13": + lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in + ("true", "1")), + # vLLM will use global cache for moe + "VLLM_USE_LIGHT_OP": + lambda: (os.environ.get("VLLM_USE_LIGHT_OP", "True").lower() in + ("true", "1")), + # vLLM will use global cache for moe + "VLLM_USE_TRITON_CAT": + lambda: (os.environ.get("VLLM_USE_TRITON_CAT", "True").lower() in + ("true", "1")), + # vLLM will use opt merge_aatn_states,not triton + "VLLM_USE_MERGE_ATTN_STATES_OPT": + lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in + ("true", "1")), + # vllm will use rmsquant fused op + "USE_FUSED_RMS_QUANT": + lambda: (os.getenv('USE_FUSED_RMS_QUANT', '0').lower() in + ("true", "1")), +} + +# --8<-- [end:env-vars-definition] + + +def __getattr__(name: str): + # lazy evaluation of environment variables + if name in environment_variables: + return environment_variables[name]() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__(): + return list(environment_variables.keys()) + + +def is_set(name: str): + """Check if an environment variable is explicitly set.""" + if name in environment_variables: + return name in os.environ + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def set_vllm_use_v1(use_v1: bool): + if is_set("VLLM_USE_V1"): + raise ValueError( + "Should not call set_vllm_use_v1() if VLLM_USE_V1 is set " + "explicitly by the user. Please raise this as a Github " + "Issue and explicitly set VLLM_USE_V1=0 or 1.") + os.environ["VLLM_USE_V1"] = "1" if use_v1 else "0" + + +def compute_hash() -> str: + """ + WARNING: Whenever a new key is added to this environment + variables, ensure that it is included in the factors list if + it affects the computation graph. For example, different values + of VLLM_PP_LAYER_PARTITION will generate different computation + graphs, so it is included in the factors list. The env vars that + affect the choice of different kernels or attention backends should + also be included in the factors list. + """ + factors: list[Any] = [] + + # summarize environment variables + def factorize(name: str): + if __getattr__(name): + factors.append(__getattr__(name)) + else: + factors.append("None") + + # The values of envs may affects the computation graph. + # TODO(DefTruth): hash all environment variables? + # for key in environment_variables: + # factorize(key) + environment_variables_to_hash = [ + "VLLM_PP_LAYER_PARTITION", + "VLLM_MLA_DISABLE", + "VLLM_USE_TRITON_FLASH_ATTN", + "VLLM_USE_TRITON_AWQ", + "VLLM_DP_RANK", + "VLLM_DP_SIZE", + "VLLM_USE_STANDALONE_COMPILE", + "VLLM_FUSED_MOE_CHUNK_SIZE", + ] + for key in environment_variables_to_hash: + if key in environment_variables: + factorize(key) + + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest() + + return hash_str diff --git a/vllm/executor/__init__.py b/vllm/executor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py new file mode 100644 index 0000000..99e1220 --- /dev/null +++ b/vllm/executor/executor_base.py @@ -0,0 +1,401 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import time +from abc import ABC, abstractmethod +from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, + Union) + +import torch.nn as nn +from typing_extensions import TypeVar + +import vllm.platforms +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sequence import ExecuteModelRequest, PoolerOutput +from vllm.utils import make_async +from vllm.worker.worker_base import WorkerBase + +logger = init_logger(__name__) + +_R = TypeVar("_R", default=Any) + + +class ExecutorBase(ABC): + """Base class for all executors. + + An executor is responsible for executing the model on one device, + or it can be a distributed executor + that can execute the model on multiple devices. + """ + + uses_ray: bool # whether the executor uses Ray for orchestration. + + def __init__( + self, + vllm_config: VllmConfig, + ) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + self._init_executor() + self.is_sleeping = False + self.sleeping_tags: set[str] = set() + + @abstractmethod + def _init_executor(self) -> None: + raise NotImplementedError + + @abstractmethod + def collective_rpc(self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: Tuple = (), + kwargs: Optional[Dict[str, Any]] = None) -> List[_R]: + """ + Execute an RPC call on all workers. + + Args: + method: Name of the worker method to execute, or a callable that + is serialized and sent to all workers to execute. + + If the method is a callable, it should accept an additional + `self` argument, in addition to the arguments passed in `args` + and `kwargs`. The `self` argument will be the worker object. + timeout: Maximum time in seconds to wait for execution. Raises a + [`TimeoutError`][] on timeout. `None` means wait indefinitely. + args: Positional arguments to pass to the worker method. + kwargs: Keyword arguments to pass to the worker method. + + Returns: + A list containing the results from each worker. + + Note: + It is recommended to use this API to only pass control messages, + and set up data-plane communication to pass data. + """ + raise NotImplementedError + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available blocks for the GPU KV cache and + swappable CPU KV cache. + + Normally, this should simply delegate to the underlying Worker. Some + ExecutorBase may require modification of the result, e.g. to ensure the + selected cache sizes are compatible with all workers. + + Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + are blocks that are "active" on the device and can be appended to. + num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be + appended to. + """ + results = self.collective_rpc("determine_num_available_blocks") + a = min([r[0] for r in results]) + b = min([r[1] for r in results]) + return a, b + + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + # NOTE: This is logged in the executor because there can be >1 workers. + logger.info("# %s blocks: %d, # CPU blocks: %d", + vllm.platforms.current_platform.device_name, + num_gpu_blocks, num_cpu_blocks) + max_concurrency = (num_gpu_blocks * self.cache_config.block_size / + self.model_config.max_model_len) + logger.info("Maximum concurrency for %s tokens per request: %.2fx", + self.model_config.max_model_len, max_concurrency) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self.collective_rpc("initialize_cache", + args=(num_gpu_blocks, num_cpu_blocks)) + + def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: + """ + Run a function directly on the model inside each worker, + returning the result for each of them. + """ + + def rpc_func(worker: WorkerBase) -> _R: + return func(worker.get_model()) + + return self.collective_rpc(rpc_func) + + def execute_model( + self, execute_model_req: ExecuteModelRequest + ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: + output = self.collective_rpc("execute_model", + args=(execute_model_req, )) + return output[0] + + def stop_remote_worker_execution_loop(self) -> None: + """Releases parallel workers from model loop.""" + return + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return all(self.collective_rpc("add_lora", args=(lora_request, ))) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return all(self.collective_rpc("remove_lora", args=(lora_id, ))) + + def pin_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return all(self.collective_rpc("pin_lora", args=(lora_id, ))) + + def list_loras(self) -> Set[int]: + sets = self.collective_rpc("list_loras") + for s in sets: + assert s == sets[0], "All workers should have the same LORAs." + return sets[0] + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + assert prompt_adapter_request.prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return all( + self.collective_rpc("add_prompt_adapter", + args=(prompt_adapter_request, ))) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return all( + self.collective_rpc("remove_prompt_adapter", + args=(prompt_adapter_id, ))) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return all( + self.collective_rpc("pin_prompt_adapter", + args=(prompt_adapter_id, ))) + + def list_prompt_adapters(self) -> Set[int]: + sets = self.collective_rpc("list_prompt_adapters") + for s in sets: + assert (s == sets[0] + ), "All workers should have the same prompt adapters." + return sets[0] + + def start_profile(self) -> None: + self.collective_rpc("start_profile") + + def stop_profile(self) -> None: + self.collective_rpc("stop_profile") + + def sleep(self, level: int = 1): + if self.is_sleeping: + logger.warning("Executor is already sleeping.") + return + time_before_sleep = time.perf_counter() + self.collective_rpc("sleep", kwargs=dict(level=level)) + time_after_sleep = time.perf_counter() + self.sleeping_tags = {"weights", "kv_cache"} + self.is_sleeping = True + logger.info("It took %.6f seconds to fall asleep.", + time_after_sleep - time_before_sleep) + + def wake_up(self, tags: Optional[list[str]] = None): + if not self.is_sleeping: + logger.warning("Executor is not sleeping.") + return + if tags: + for tag in tags: + if tag not in self.sleeping_tags: + logger.warning("Tag %s is not in sleeping tags %s", tag, + self.sleeping_tags) + return + time_before_wakeup = time.perf_counter() + self.collective_rpc("wake_up", kwargs=dict(tags=tags)) + time_after_wakeup = time.perf_counter() + logger.info("It took %.6f seconds to wake up tags %s.", + time_after_wakeup - time_before_wakeup, + tags if tags is not None else self.sleeping_tags) + if tags: + for tag in tags: + self.sleeping_tags.remove(tag) + else: + self.sleeping_tags.clear() + if not self.sleeping_tags: + self.is_sleeping = False + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + self.collective_rpc("save_sharded_state", + kwargs=dict(path=path, + pattern=pattern, + max_size=max_size)) + + @abstractmethod + def check_health(self) -> None: + """Checks if the executor is healthy. If not, it should raise an + exception.""" + raise NotImplementedError + + def shutdown(self) -> None: + """Shutdown the executor.""" + return + + def __del__(self): + self.shutdown() + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + """Executes one model step on the given sequences.""" + output = await make_async(self.execute_model)(execute_model_req) + return output + + async def stop_remote_worker_execution_loop_async(self) -> None: + """Releases parallel workers from model loop.""" + return + + async def check_health_async(self) -> None: + """Checks if the executor is healthy. If not, it should raise an + exception.""" + self.check_health() + + +class DistributedExecutorBase(ExecutorBase): + """Abstract superclass of distributed executor implementations.""" + + def __init__(self, *args, **kwargs): + # This is non-None when the execute model loop is running + # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. + self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None + + super().__init__(*args, **kwargs) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: + # TODO: unify into collective_rpc + if self.parallel_worker_tasks is None: + self.parallel_worker_tasks = self._run_workers( + "start_worker_execution_loop", + async_run_tensor_parallel_workers_only=True) + + # Only the driver worker returns the sampling results. + driver_outputs = self._driver_execute_model(execute_model_req) + assert driver_outputs is not None + return driver_outputs + + def stop_remote_worker_execution_loop(self) -> None: + if self.parallel_worker_tasks is None: + return + + self._driver_execute_model(execute_model_req=None) + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + self._wait_for_tasks_completion(parallel_worker_tasks) + + @abstractmethod + def _driver_execute_model( + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> Optional[List[SamplerOutput]]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution loop + running in each of the remote workers. In this case, this method + returns None. Otherwise, this method returns the model output. + """ + raise NotImplementedError + + def collective_rpc(self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: Tuple = (), + kwargs: Optional[Dict] = None) -> List[Any]: + return self._run_workers(method, *args, **(kwargs or {})) + + @abstractmethod + def _run_workers( + self, + method: Union[str, Callable], + *args, + async_run_tensor_parallel_workers_only: bool = False, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers. + + Args: + async_run_tensor_parallel_workers_only: If True the method will be + run only in the remote TP workers, not the driver worker. + It will also be run asynchronously and return a list of futures + rather than blocking on the results. + + # TODO: simplify and merge with collective_rpc + """ + raise NotImplementedError + + @abstractmethod + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + raise NotImplementedError + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if self.parallel_worker_tasks is None: + # Start model execution loop running in the parallel workers + self.parallel_worker_tasks = asyncio.create_task( + self._start_worker_execution_loop()) + + # Only the driver worker returns the sampling results. + return await self._driver_execute_model_async(execute_model_req) + + async def stop_remote_worker_execution_loop_async(self) -> None: + if self.parallel_worker_tasks is None: + return + + await self._driver_execute_model_async() + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + await parallel_worker_tasks + + @abstractmethod + async def _driver_execute_model_async( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> List[SamplerOutput]: + """Execute the model asynchronously in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + raise NotImplementedError + + @abstractmethod + async def _start_worker_execution_loop(self): + """Run execution loop on all workers. It guarantees all workers run + the loop or None of them is running the loop. Loop can be stopped by + `stop_remote_worker_execution_loop`. + The API is idempotent (guarantee only 1 loop run at any moment).""" + raise NotImplementedError diff --git a/vllm/executor/mp_distributed_executor.py b/vllm/executor/mp_distributed_executor.py new file mode 100644 index 0000000..52ac472 --- /dev/null +++ b/vllm/executor/mp_distributed_executor.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import os +from typing import Any, Callable, List, Optional, Union + +import cloudpickle + +from vllm.executor.executor_base import DistributedExecutorBase +from vllm.executor.multiproc_worker_utils import ( + ProcessWorkerWrapper, ResultHandler, WorkerMonitor, + set_multiprocessing_worker_envs) +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless, + get_distributed_init_method, get_ip, get_open_port, + make_async, run_method, update_environment_variables) +from vllm.worker.worker_base import WorkerWrapperBase + +logger = init_logger(__name__) + + +class MultiprocessingDistributedExecutor(DistributedExecutorBase): + """Python multiprocessing-based distributed executor""" + + uses_ray: bool = False + + def _check_cuda(self) -> None: + """Check that the number of GPUs is sufficient for the parallel + configuration. Separate from _init_executor to reduce the number of + indented blocks. + """ + parallel_config = self.parallel_config + world_size = parallel_config.world_size + tensor_parallel_size = parallel_config.tensor_parallel_size + + cuda_device_count = cuda_device_count_stateless() + # Use confusing message for more common TP-only case. + if tensor_parallel_size > cuda_device_count: + raise RuntimeError( + f"please set tensor_parallel_size ({tensor_parallel_size}) " + f"to less than max local gpu count ({cuda_device_count})") + + if world_size > cuda_device_count: + raise RuntimeError( + f"please ensure that world_size ({world_size}) " + f"is less than than max local gpu count ({cuda_device_count})") + + # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers + if "CUDA_VISIBLE_DEVICES" or "HIP_VISIBLE_DEVICES" not in os.environ: + update_environment_variables({ + "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size)))) + }) + + def _init_executor(self) -> None: + + from vllm.platforms import current_platform + if current_platform.is_cuda_alike(): + self._check_cuda() + + # Create the parallel GPU workers. + world_size = self.parallel_config.world_size + tensor_parallel_size = self.parallel_config.tensor_parallel_size + + # Set multiprocessing envs that are common to V0 and V1 + set_multiprocessing_worker_envs(self.parallel_config) + + # Multiprocessing-based executor does not support multi-node setting. + # Since it only works for single node, we can use the loopback address + # 127.0.0.1 for communication. + distributed_init_method = get_distributed_init_method( + "127.0.0.1", get_open_port()) + + self.workers: List[ProcessWorkerWrapper] = [] + # This is the list of workers that are rank 0 of each TP group EXCEPT + # global rank 0. These are the workers that will broadcast to the + # rest of the workers. + self.tp_driver_workers: List[ProcessWorkerWrapper] = [] + # This is the list of workers that are not drivers and not the first + # worker in a TP group. These are the workers that will be + # broadcasted to. + self.non_driver_workers: List[ProcessWorkerWrapper] = [] + + if world_size == 1: + self.worker_monitor = None + else: + result_handler = ResultHandler() + for rank in range(1, world_size): + worker = ProcessWorkerWrapper(result_handler, + WorkerWrapperBase, + self.vllm_config, rank) + self.workers.append(worker) + if rank % tensor_parallel_size == 0: + self.tp_driver_workers.append(worker) + else: + self.non_driver_workers.append(worker) + + self.worker_monitor = WorkerMonitor(self.workers, result_handler) + result_handler.start() + self.worker_monitor.start() + + # Set up signal handlers to shutdown the executor cleanly + # sometimes gc does not work well + + self.driver_worker = WorkerWrapperBase(self.vllm_config, 0) + + all_kwargs = [] + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + for i in range(world_size): + local_rank = i + rank = i + kwargs = dict( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=(not self.parallel_config) + or (rank % self.parallel_config.tensor_parallel_size == 0), + ) + all_kwargs.append(kwargs) + self._run_workers("init_worker", all_kwargs) + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + self.driver_exec_model = make_async(self.driver_worker.execute_model) + self.pp_locks: Optional[List[asyncio.Lock]] = None + + def shutdown(self): + if (worker_monitor := getattr(self, "worker_monitor", + None)) is not None: + worker_monitor.close() + + def _driver_execute_model( + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> Optional[List[SamplerOutput]]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + return self.driver_worker.execute_model(execute_model_req) + + def _run_workers( + self, + method: Union[str, Callable], + *args, + async_run_tensor_parallel_workers_only: bool = False, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> List[Any]: + """Runs the given method on all workers. + + Args: + async_run_tensor_parallel_workers_only: If True the method will be + run only in the remote TP workers, not the driver worker. + It will also be run asynchronously and return a list of futures + rather than blocking on the results. + """ + if isinstance(method, str): + sent_method = method + else: + sent_method = cloudpickle.dumps(method) + del method + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + if async_run_tensor_parallel_workers_only: + # Run only non-driver workers and just return futures. + return [ + worker.execute_method(sent_method, *args, **kwargs) + for worker in self.non_driver_workers + ] + + # Start all remote workers first. + worker_outputs = [ + worker.execute_method(sent_method, *args, **kwargs) + for worker in self.workers + ] + + driver_worker_output = run_method(self.driver_worker, sent_method, + args, kwargs) + + # Get the results of the workers. + return [driver_worker_output + ] + [output.get() for output in worker_outputs] + + def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + if self.worker_monitor is not None and not self.worker_monitor.is_alive( + ): + raise RuntimeError("Worker processes are not running") + + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + for result in parallel_worker_tasks: + result.get() + + async def _driver_execute_model_async( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + if not self.tp_driver_workers: + return await self.driver_exec_model(execute_model_req) + + if self.pp_locks is None: + # This locks each pipeline parallel stage so multiple virtual + # engines can't execute on the same stage at the same time + # We create the locks here to avoid creating them in the constructor + # which uses a different asyncio loop. + self.pp_locks = [ + asyncio.Lock() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + tasks = [ + asyncio.create_task( + _run_task_with_lock(self.driver_exec_model, self.pp_locks[0], + execute_model_req)) + ] + for pp_rank, driver_worker in enumerate(self.tp_driver_workers, + start=1): + tasks.append( + asyncio.create_task( + _run_task_with_lock(driver_worker.execute_method_async, + self.pp_locks[pp_rank], + "execute_model", execute_model_req))) + results = await asyncio.gather(*tasks) + + # Only the last PP stage has the final results. + return results[-1] + + async def _start_worker_execution_loop(self): + coros = [ + worker.execute_method_async("start_worker_execution_loop") + for worker in self.non_driver_workers + ] + return await asyncio.gather(*coros) diff --git a/vllm/executor/msgspec_utils.py b/vllm/executor/msgspec_utils.py new file mode 100644 index 0000000..852c8f5 --- /dev/null +++ b/vllm/executor/msgspec_utils.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from array import array +from typing import Any, Type + +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE + + +def encode_hook(obj: Any) -> Any: + """Custom msgspec enc hook that supports array types. + + See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder + """ + if isinstance(obj, array): + assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, ( + f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. " + f"Given array has a type code of {obj.typecode}.") + return obj.tobytes() + + +def decode_hook(type: Type, obj: Any) -> Any: + """Custom msgspec dec hook that supports array types. + + See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder + """ + if type is array: + deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE) + deserialized.frombytes(obj) + return deserialized diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py new file mode 100644 index 0000000..a6c172b --- /dev/null +++ b/vllm/executor/multiproc_worker_utils.py @@ -0,0 +1,313 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import os +import sys +import threading +import uuid +from dataclasses import dataclass +from multiprocessing import Queue +from multiprocessing.connection import wait +from multiprocessing.process import BaseProcess +from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO, + TypeVar, Union) + +import torch + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.utils import _maybe_force_spawn, get_mp_context, run_method + +logger = init_logger(__name__) + +T = TypeVar('T') + +_TERMINATE = "TERMINATE" # sentinel + +# ANSI color codes +CYAN = '\033[1;36m' +RESET = '\033[0;0m' + +JOIN_TIMEOUT_S = 2 + + +@dataclass +class Result(Generic[T]): + """Result of task dispatched to worker""" + + task_id: uuid.UUID + value: Optional[T] = None + exception: Optional[BaseException] = None + + +class ResultFuture(threading.Event, Generic[T]): + """Synchronous future for non-async case""" + + def __init__(self): + super().__init__() + self.result: Optional[Result[T]] = None + + def set_result(self, result: Result[T]): + self.result = result + self.set() + + def get(self) -> T: + self.wait() + assert self.result is not None + if self.result.exception is not None: + raise self.result.exception + return self.result.value # type: ignore[return-value] + + +def _set_future_result(future: Union[ResultFuture, asyncio.Future], + result: Result): + if isinstance(future, ResultFuture): + future.set_result(result) + return + loop = future.get_loop() + if not loop.is_closed(): + if result.exception is not None: + loop.call_soon_threadsafe(future.set_exception, result.exception) + else: + loop.call_soon_threadsafe(future.set_result, result.value) + + +class ResultHandler(threading.Thread): + """Handle results from all workers (in background thread)""" + + def __init__(self) -> None: + super().__init__(daemon=True) + self.result_queue = get_mp_context().Queue() + self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {} + + def run(self): + for result in iter(self.result_queue.get, _TERMINATE): + future = self.tasks.pop(result.task_id) + _set_future_result(future, result) + # Ensure that all waiters will receive an exception + for task_id, future in self.tasks.items(): + _set_future_result( + future, + Result(task_id=task_id, + exception=ChildProcessError("worker died"))) + + def close(self): + self.result_queue.put(_TERMINATE) + + +class WorkerMonitor(threading.Thread): + """Monitor worker status (in background thread)""" + + def __init__(self, workers: List['ProcessWorkerWrapper'], + result_handler: ResultHandler): + super().__init__(daemon=True) + self.workers = workers + self.result_handler = result_handler + self._close = False + + def run(self) -> None: + # Blocks until any worker exits + dead_sentinels = wait([w.process.sentinel for w in self.workers]) + if not self._close: + self._close = True + + # Kill / cleanup all workers + for worker in self.workers: + process = worker.process + if process.sentinel in dead_sentinels: + process.join(JOIN_TIMEOUT_S) + if process.exitcode is not None and process.exitcode != 0: + logger.error("Worker %s pid %s died, exit code: %s", + process.name, process.pid, process.exitcode) + # Cleanup any remaining workers + if logger: + logger.info("Killing local vLLM worker processes") + for worker in self.workers: + worker.kill_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() + + for worker in self.workers: + worker.process.join(JOIN_TIMEOUT_S) + + def close(self): + if self._close: + return + self._close = True + logger.info("Terminating local vLLM worker processes") + for worker in self.workers: + worker.terminate_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() + + +class ProcessWorkerWrapper: + """Local process wrapper for vllm.worker.Worker, + for handling single-node multi-GPU tensor parallel.""" + + def __init__(self, result_handler: ResultHandler, + worker_factory: Callable[[VllmConfig, int], Any], + vllm_config: VllmConfig, rank: int) -> None: + self.mp = get_mp_context() + self._task_queue = self.mp.Queue() + self.result_queue = result_handler.result_queue + self.tasks = result_handler.tasks + self.process: BaseProcess = self.mp.Process( # type: ignore[attr-defined] + target=_run_worker_process, + name="VllmWorkerProcess", + kwargs=dict( + worker_factory=worker_factory, + task_queue=self._task_queue, + result_queue=self.result_queue, + vllm_config=vllm_config, + rank=rank, + ), + daemon=True) + + self.process.start() + + def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], + method: Union[str, bytes], args, kwargs): + task_id = uuid.uuid4() + self.tasks[task_id] = future + try: + self._task_queue.put((task_id, method, args, kwargs)) + except SystemExit: + raise + except BaseException as e: + del self.tasks[task_id] + raise ChildProcessError("worker died") from e + + def execute_method(self, method: Union[str, bytes], *args, **kwargs): + future: ResultFuture = ResultFuture() + self._enqueue_task(future, method, args, kwargs) + return future + + async def execute_method_async(self, method: Union[str, bytes], *args, + **kwargs): + future = asyncio.get_running_loop().create_future() + self._enqueue_task(future, method, args, kwargs) + return await future + + def terminate_worker(self): + try: + self._task_queue.put(_TERMINATE) + except ValueError: + self.process.kill() + self._task_queue.close() + + def kill_worker(self): + self._task_queue.close() + self.process.kill() + + +def _run_worker_process( + worker_factory: Callable[[VllmConfig, int], Any], + task_queue: Queue, + result_queue: Queue, + vllm_config: VllmConfig, + rank: int, +) -> None: + """Worker process event loop""" + + # Add process-specific prefix to stdout and stderr + process_name = get_mp_context().current_process().name + pid = os.getpid() + _add_prefix(sys.stdout, process_name, pid) + _add_prefix(sys.stderr, process_name, pid) + + # Initialize worker + worker = worker_factory(vllm_config, rank) + del worker_factory + + # Accept tasks from the engine in task_queue + # and return task output in result_queue + logger.info("Worker ready; awaiting tasks") + try: + for items in iter(task_queue.get, _TERMINATE): + output = None + exception = None + task_id, method, args, kwargs = items + try: + output = run_method(worker, method, args, kwargs) + except SystemExit: + raise + except KeyboardInterrupt: + break + except BaseException as e: + logger.exception( + "Exception in worker %s while processing method %s.", + process_name, method) + exception = e + result_queue.put( + Result(task_id=task_id, value=output, exception=exception)) + except KeyboardInterrupt: + pass + except Exception: + logger.exception("Worker failed") + + # Flush TunableOp results when TunableOp is enabled and + # online (in situ) tuning is enabled. + # Offline tuning API (record_untuned_is_enabled()) only + # available in PyTorch 2.6 or later. + if torch.cuda.is_available(): + import torch.cuda.tunable as tunable + if (tunable.is_enabled() and tunable.tuning_is_enabled() + and not tunable.record_untuned_is_enabled()): + tunable.write_file() + + logger.info("Worker exiting") + + +def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: + """Prepend each output line with process-specific prefix""" + + prefix = f"{CYAN}({worker_name} pid={pid}){RESET} " + file_write = file.write + + def write_with_prefix(s: str): + if not s: + return + if file.start_new_line: # type: ignore[attr-defined] + file_write(prefix) + idx = 0 + while (next_idx := s.find('\n', idx)) != -1: + next_idx += 1 + file_write(s[idx:next_idx]) + if next_idx == len(s): + file.start_new_line = True # type: ignore[attr-defined] + return + file_write(prefix) + idx = next_idx + file_write(s[idx:]) + file.start_new_line = False # type: ignore[attr-defined] + + file.start_new_line = True # type: ignore[attr-defined] + file.write = write_with_prefix # type: ignore[method-assign] + + +def set_multiprocessing_worker_envs(parallel_config): + """ Set up environment variables that should be used when there are workers + in a multiprocessing environment. This should be called by the parent + process before worker processes are created""" + + _maybe_force_spawn() + + # Configure thread parallelism if OMP_NUM_THREADS isn't set + # + # Helps to avoid CPU contention. The default of spawning a thread per + # core combined with multiprocessing for each GPU can have a negative + # impact on performance. The contention is amplified when running in a + # container where CPU limits can cause throttling. + default_omp_num_threads = 1 + if "OMP_NUM_THREADS" not in os.environ and ( + current_parallelism := + torch.get_num_threads()) > default_omp_num_threads: + logger.warning( + "Reducing Torch parallelism from %d threads to %d to avoid " + "unnecessary CPU contention. Set OMP_NUM_THREADS in the " + "external environment to tune this value as needed.", + current_parallelism, default_omp_num_threads) + os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads) + torch.set_num_threads(default_omp_num_threads) diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py new file mode 100644 index 0000000..84e8ddd --- /dev/null +++ b/vllm/executor/ray_distributed_executor.py @@ -0,0 +1,701 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import json +import os +from collections import defaultdict +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union + +import cloudpickle +import msgspec + +import vllm.envs as envs +from vllm.executor.executor_base import ( + DistributedExecutorBase) # yapf: disable +from vllm.executor.msgspec_utils import encode_hook +from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster, + ray) +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.platforms import current_platform +from vllm.sequence import ExecuteModelRequest +from vllm.utils import (_run_task_with_lock, get_distributed_init_method, + get_ip, get_open_port, make_async) + +if ray is not None: + from ray.actor import ActorHandle + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +else: + ActorHandle = None + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + +logger = init_logger(__name__) + + +@dataclass +class RayWorkerMetaData: + """ + Metadata for a Ray worker. + The order of ray worker creation can be random, + and we need to reset the rank after creating all workers. + """ + worker: ActorHandle + created_rank: int + adjusted_rank: int = -1 + ip: str = "" + + +class RayDistributedExecutor(DistributedExecutorBase): + """Ray-based distributed executor""" + + # These env vars are worker-specific, therefore are NOT copied + # from the driver to the workers + WORKER_SPECIFIC_ENV_VARS = { + "VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES" + } + + config_home = envs.VLLM_CONFIG_ROOT + # This file contains a list of env vars that should not be copied + # from the driver to the Ray workers. + non_carry_over_env_vars_file = os.path.join( + config_home, "ray_non_carry_over_env_vars.json") + if os.path.exists(non_carry_over_env_vars_file): + with open(non_carry_over_env_vars_file) as f: + non_carry_over_env_vars = set(json.load(f)) + else: + non_carry_over_env_vars = set() + + uses_ray: bool = True + + def _init_executor(self) -> None: + self.forward_dag: Optional[ray.dag.CompiledDAG] = None + if envs.VLLM_USE_V1 and not current_platform.is_xpu(): + # V1 uses SPMD worker and compiled DAG + os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1" + os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1" + + # For TPU, avoid compiling NVIDIA's NCCL + if current_platform.is_tpu(): + os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm" + + # If the env var is set, it uses the Ray's compiled DAG API + # which optimizes the control plane overhead. + # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. + # Currently, this requires USE_RAY_SPMD_WORKER=True. + self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG + # If the env var is set, then we do not distinguish between the + # "driver worker" vs other workers. Also, the rank 0 worker will + # be executed in a remote Ray worker. Currently this requires + # USE_RAY_COMPILED_DAG=True. + self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER + if self.use_ray_compiled_dag: + assert self.use_ray_spmd_worker, ( + "VLLM_USE_RAY_COMPILED_DAG=1 requires " + "VLLM_USE_RAY_SPMD_WORKER=1") + if self.use_ray_spmd_worker: + # TODO: Support SPMD worker for non-DAG Ray executor. + assert self.use_ray_compiled_dag, ( + "VLLM_USE_RAY_SPMD_WORKER=1 requires " + "VLLM_USE_RAY_COMPILED_DAG=1") + + assert self.uses_ray + initialize_ray_cluster(self.parallel_config) + placement_group = self.parallel_config.placement_group + + # Disable Ray usage stats collection. + ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") + if ray_usage != "1": + os.environ["RAY_USAGE_STATS_ENABLED"] = "0" + + # Create the parallel GPU workers. + self._init_workers_ray(placement_group) + + self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) + self.output_decoder = msgspec.msgpack.Decoder( + Optional[List[SamplerOutput]]) + self.use_v1 = envs.VLLM_USE_V1 + + self.pp_locks: Optional[List[asyncio.Lock]] = None + if not self.use_ray_compiled_dag: + self.driver_exec_method = make_async( + self.driver_worker.execute_method) + + def shutdown(self) -> None: + logger.info( + "Shutting down Ray distributed executor. If you see error log " + "from logging.cc regarding SIGTERM received, please ignore because " + "this is the expected termination process in Ray.") + if hasattr(self, "forward_dag") and self.forward_dag is not None: + self.forward_dag.teardown() + import ray + for worker in self.workers: + ray.kill(worker) + self.forward_dag = None + + def _configure_ray_workers_use_nsight(self, + ray_remote_kwargs) -> Dict[str, Any]: + # If nsight profiling is enabled, we need to set the profiling + # configuration for the ray workers as runtime env. + runtime_env = ray_remote_kwargs.setdefault("runtime_env", {}) + runtime_env.update({ + "nsight": { + "t": "cuda,cudnn,cublas", + "o": "'worker_process_%p'", + "cuda-graph-trace": "node", + } + }) + + return ray_remote_kwargs + + # child class could overwrite this to return actual env vars. + def _get_env_vars_to_be_updated(self): + return self._env_vars_for_all_workers + + def _init_workers_ray(self, placement_group: "PlacementGroup", + **ray_remote_kwargs): + num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS + + # The driver dummy worker does not actually use any resources. + # It holds the resource for the driver worker. + self.driver_dummy_worker: Optional[RayWorkerWrapper] = None + # The remaining workers are the actual ray actors. + self.workers: List[RayWorkerWrapper] = [] + + # Used in ray compiled DAG: indexed first by PP rank, + # and then TP rank. In other words, the inner list is + # the TP group of workers for a PP rank. + self.pp_tp_workers: List[List[RayWorkerWrapper]] = [] + + if self.parallel_config.ray_workers_use_nsight: + ray_remote_kwargs = self._configure_ray_workers_use_nsight( + ray_remote_kwargs) + + logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) + + # Create the workers. + bundle_indices: List[int] + if envs.VLLM_RAY_BUNDLE_INDICES: + # Use the bundle indices specified by the user. + bundle_indices = list( + map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(","))) + assert len(bundle_indices) == self.parallel_config.world_size, \ + ("VLLM_RAY_BUNDLE_INDICES must have the same size" + f" as the world size, but got {bundle_indices=} " + f"and {self.parallel_config.world_size=}") + assert len(set(bundle_indices)) == len(bundle_indices), \ + ("VLLM_RAY_BUNDLE_INDICES cannot have duplicate values," + f" but got {bundle_indices=}") + else: + # use the first N bundles that have GPU resources. + bundle_indices = [] + for bundle_id, bundle in enumerate(placement_group.bundle_specs): + if bundle.get(current_platform.ray_device_key, 0): + bundle_indices.append(bundle_id) + bundle_indices = bundle_indices[:self.parallel_config.world_size] + + worker_metadata: List[RayWorkerMetaData] = [] + driver_ip = get_ip() + for rank, bundle_id in enumerate(bundle_indices): + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=bundle_id, + ) + + if current_platform.ray_device_key == "GPU": + # NV+AMD GPUs, and Intel XPUs + worker = ray.remote( + num_cpus=0, + num_gpus=num_gpus, + scheduling_strategy=scheduling_strategy, + **ray_remote_kwargs, + )(RayWorkerWrapper).remote(vllm_config=self.vllm_config, + rpc_rank=rank) + else: + worker = ray.remote( + num_cpus=0, + num_gpus=0, + resources={current_platform.ray_device_key: num_gpus}, + scheduling_strategy=scheduling_strategy, + **ray_remote_kwargs, + )(RayWorkerWrapper).remote(vllm_config=self.vllm_config, + rpc_rank=rank) + worker_metadata.append( + RayWorkerMetaData(worker=worker, created_rank=rank)) + + worker_ips = ray.get([ + each.worker.get_node_ip.remote() # type: ignore[attr-defined] + for each in worker_metadata + ]) + + for each, ip in zip(worker_metadata, worker_ips): + each.ip = ip + + if not self.use_ray_spmd_worker: + for i, each in enumerate(worker_metadata): + # find and remove the dummy worker from the list + worker = each.worker + worker_ip = each.ip + if self.driver_dummy_worker is None and worker_ip == driver_ip: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + vllm_config=self.vllm_config, rpc_rank=0) + worker_metadata.pop(i) + break + + logger.debug("workers: %s", worker_metadata) + logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) + if not self.use_ray_spmd_worker and self.driver_dummy_worker is None: + raise ValueError( + "Ray does not allocate any GPUs on the driver node." + f"Driver IP: {driver_ip}, worker IPs: {worker_ips}." + "Consider adjusting the Ray placement group or running " + "the driver on a GPU node.") + + ip_counts: Dict[str, int] = {} + for ip in worker_ips: + ip_counts[ip] = ip_counts.get(ip, 0) + 1 + + def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): + """ + Sort the workers based on 3 properties: + 1. If the worker is on the same node as the driver (vllm engine), + it should be placed first. + 2. Then, if the worker is on a node with fewer workers, it should + be placed first. + 3. Finally, if the work is on a node with smaller IP address, it + should be placed first. + """ + ip = item.ip + return (0 if ip == driver_ip else 1, ip_counts[ip], ip) + + # After sorting, the workers on the same node will be + # close to each other, and the workers on the driver + # node will be placed first. + sorted_worker_metadata = sorted(worker_metadata, + key=sort_by_driver_then_worker_ip) + start_rank = 0 if self.use_ray_spmd_worker else 1 + for i, item in enumerate(sorted_worker_metadata): + item.adjusted_rank = i + start_rank + self.workers = [item.worker for item in sorted_worker_metadata] + rerank_mapping = { + item.created_rank: item.adjusted_rank + for item in sorted_worker_metadata + } + self._run_workers("adjust_rank", rerank_mapping) + + # Get the set of GPU IDs used on each node. + worker_node_and_gpu_ids = [] + for worker in [self.driver_dummy_worker] + self.workers: + if worker is None: + # driver_dummy_worker can be None when using ray spmd worker. + continue + worker_node_and_gpu_ids.append( + ray.get(worker.get_node_and_gpu_ids.remote()) \ + ) # type: ignore + + node_workers = defaultdict(list) # node id -> list of worker ranks + node_gpus = defaultdict(list) # node id -> list of gpu ids + + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): + node_workers[node_id].append(i) + # `gpu_ids` can be a list of strings or integers. + # convert them to integers for consistency. + # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs), + # string sorting is not sufficient. + # see https://github.com/vllm-project/vllm/issues/5590 + gpu_ids = [int(x) for x in gpu_ids] + node_gpus[node_id].extend(gpu_ids) + for node_id, gpu_ids in node_gpus.items(): + node_gpus[node_id] = sorted(gpu_ids) + + all_ips = set(worker_ips + [driver_ip]) + n_ips = len(all_ips) + n_nodes = len(node_workers) + + if n_nodes != n_ips: + raise RuntimeError( + f"Every node should have a unique IP address. Got {n_nodes}" + f" nodes with node ids {list(node_workers.keys())} and " + f"{n_ips} unique IP addresses {all_ips}. Please check your" + " network configuration. If you set `VLLM_HOST_IP`" + " environment variable, make sure it is unique for" + " each node.") + + # Set environment variables for the driver and workers. + all_args_to_update_environment_variables = [{ + current_platform.device_control_env_var: + ",".join(map(str, node_gpus[node_id])), + } for (node_id, _) in worker_node_and_gpu_ids] + + # Environment variables to copy from driver to workers + env_vars_to_copy = [ + v for v in envs.environment_variables + if v not in self.WORKER_SPECIFIC_ENV_VARS + and v not in self.non_carry_over_env_vars + ] + + env_vars_to_copy.extend(current_platform.additional_env_vars) + + # Copy existing env vars to each worker's args + for args in all_args_to_update_environment_variables: + # TODO: refactor platform-specific env vars + for name in env_vars_to_copy: + if name in os.environ: + args[name] = os.environ[name] + + logger.info("non_carry_over_env_vars from config: %s", + self.non_carry_over_env_vars) + logger.info( + "Copying the following environment variables to workers: %s", + [v for v in env_vars_to_copy if v in os.environ]) + logger.info( + "If certain env vars should NOT be copied to workers, add them to " + "%s file", self.non_carry_over_env_vars_file) + + self._env_vars_for_all_workers = ( + all_args_to_update_environment_variables) + + self._run_workers("update_environment_variables", + self._get_env_vars_to_be_updated()) + + if len(node_gpus) == 1: + # in single node case, we don't need to get the IP address. + # the loopback address is sufficient + # NOTE: a node may have several IP addresses, one for each + # network interface. `get_ip()` might return any of them, + # while they might not work for communication inside the node + # if the network setup is complicated. Using the loopback address + # solves this issue, as it always works for communication inside + # the node. + driver_ip = "127.0.0.1" + distributed_init_method = get_distributed_init_method( + driver_ip, get_open_port()) + + # Initialize the actual workers inside worker wrapper. + all_kwargs = [] + for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids): + local_rank = node_workers[node_id].index(rank) + kwargs = dict( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=(not self.parallel_config) + or (rank % self.parallel_config.tensor_parallel_size == 0), + ) + all_kwargs.append(kwargs) + self._run_workers("init_worker", all_kwargs) + + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + + if self.use_ray_spmd_worker: + for pp_rank in range(self.parallel_config.pipeline_parallel_size): + self.pp_tp_workers.append([]) + for tp_rank in range( + self.parallel_config.tensor_parallel_size): + # PP=2, TP=4 + # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]] + rank = (pp_rank * self.parallel_config.tensor_parallel_size + ) + tp_rank + assert len(self.pp_tp_workers[pp_rank]) == tp_rank + assert pp_rank < len(self.pp_tp_workers) + self.pp_tp_workers[pp_rank].append(self.workers[rank]) + + # This is the list of workers that are rank 0 of each TP group EXCEPT + # global rank 0. These are the workers that will broadcast to the + # rest of the workers. + self.tp_driver_workers: List[RayWorkerWrapper] = [] + # This is the list of workers that are not drivers and not the first + # worker in a TP group. These are the workers that will be + # broadcasted to. + self.non_driver_workers: List[RayWorkerWrapper] = [] + + # Enforce rank order for correct rank to return final output. + for index, worker in enumerate(self.workers): + # The driver worker is rank 0 and not in self.workers. + rank = index + 1 + if rank % self.parallel_config.tensor_parallel_size == 0: + self.tp_driver_workers.append(worker) + else: + self.non_driver_workers.append(worker) + + def _driver_execute_model( + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> Optional[List[SamplerOutput]]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + assert not self.use_ray_spmd_worker, ( + "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1") + return self.driver_worker.execute_method("execute_model", + execute_model_req) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if not self.use_ray_spmd_worker: + return super().execute_model(execute_model_req) + + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) + + if self.use_v1: + serialized_data = execute_model_req + else: + serialized_data = self.input_encoder.encode(execute_model_req) + outputs = ray.get(self.forward_dag.execute(serialized_data)) + if self.use_v1: + output = outputs[0] + else: + output = self.output_decoder.decode(outputs[0]) + return output + + def _run_workers( + self, + method: Union[str, Callable], + *args, + async_run_tensor_parallel_workers_only: bool = False, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers. Can be used in the following + ways: + + Args: + - async_run_tensor_parallel_workers_only: If True the method will be + run only in the remote TP workers, not the driver worker. + It will also be run asynchronously and return a list of futures + rather than blocking on the results. + - args/kwargs: All workers share the same args/kwargs + """ + if isinstance(method, str): + sent_method = method + else: + sent_method = cloudpickle.dumps(method) + del method + if self.use_ray_spmd_worker: + assert not async_run_tensor_parallel_workers_only, ( + "async_run_tensor_parallel_workers_only is not supported for " + "spmd mode.") + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + # Start the ray workers first. + ray_workers = self.workers + if async_run_tensor_parallel_workers_only: + ray_workers = self.non_driver_workers + ray_worker_outputs = [ + worker.execute_method.remote(sent_method, *args, **kwargs) + for worker in ray_workers + ] + + if async_run_tensor_parallel_workers_only: + # Just return futures + return ray_worker_outputs + + driver_worker_output = [] + # In SPMD mode, the driver worker is the same as any other worker, + # so we only explicitly execute on the driver worker if using a + # non-SPMD worker class. + if not self.use_ray_spmd_worker: + # Start the driver worker after all the ray workers. + driver_worker_output = [ + self.driver_worker.execute_method(sent_method, *args, **kwargs) + ] + + # Get the results of the ray workers. + if self.workers: + ray_worker_outputs = ray.get(ray_worker_outputs) + + return driver_worker_output + ray_worker_outputs + + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + ray.get(parallel_worker_tasks) + + def _check_ray_cgraph_installation(self): + import importlib.metadata + + from packaging import version + + required_version = version.parse("2.43.0") + current_version = version.parse(importlib.metadata.version("ray")) + if current_version < required_version: + raise ValueError(f"Ray version {required_version} is " + f"required, but found {current_version}") + + import importlib.util + cgraph_spec = importlib.util.find_spec( + "ray.experimental.compiled_dag_ref") + if cgraph_spec is None: + raise ValueError("Ray Compiled Graph is not installed. " + "Run `pip install ray[cgraph]` to install it.") + + cupy_spec = importlib.util.find_spec("cupy") + if (cupy_spec is None + and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"): + raise ValueError( + "cupy is not installed but required since " + "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. " + "Run `pip install ray[cgraph]` and check cupy installation.") + + def _compiled_ray_dag(self, enable_asyncio: bool): + assert self.parallel_config.use_ray + self._check_ray_cgraph_installation() + # Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds + # (it is 10 seconds by default). This is a Ray environment variable to + # control the timeout of getting result from a compiled graph execution, + # i.e., the distributed execution that includes model forward runs and + # intermediate tensor communications, in the case of vllm. + # Note: we should set this env var before importing + # ray.dag, otherwise it will not take effect. + os.environ.setdefault("RAY_CGRAPH_get_timeout", "300") # noqa: SIM112 + from ray.dag import InputNode, MultiOutputNode + logger.info("RAY_CGRAPH_get_timeout is set to %s", + os.environ["RAY_CGRAPH_get_timeout"]) # noqa: SIM112 + logger.info("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s", + envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE) + logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s", + envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM) + + channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE + if channel_type not in ("auto", "nccl", "shm"): + raise ValueError( + "Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: " + f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'.") + + with InputNode() as input_data: + # Example DAG: PP=2, TP=4 + # + # For V0: + # ExecuteModelRequest -> 0 -> (ExecuteModelReq, IntermediateTensors) -> 4 -> SamplerOutput # noqa: E501 + # ExecuteModelRequest -> 1 -> (ExecuteModelReq, IntermediateTensors) -> 5 -> SamplerOutput # noqa: E501 + # ExecuteModelRequest -> 2 -> (ExecuteModelReq, IntermediateTensors) -> 6 -> SamplerOutput # noqa: E501 + # ExecuteModelRequest -> 3 -> (ExecuteModelReq, IntermediateTensors) -> 7 -> SamplerOutput # noqa: E501 + # + # For V1: + # SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput # noqa: E501 + # SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput # noqa: E501 + # SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput # noqa: E501 + # SchedulerOutput -> 3 -> (SchedulerOutput, IntermediateTensors) -> 7 -> ModelRunnerOutput # noqa: E501 + + # All workers in the first TP group will take in the + # ExecuteModelRequest as input. + outputs = [input_data for _ in self.pp_tp_workers[0]] + for pp_rank, tp_group in enumerate(self.pp_tp_workers): + # Each PP worker takes in the output of the previous PP worker, + # and the TP group executes in SPMD fashion. + if self.use_v1: + outputs = [ + worker.execute_model_ray. + bind( # type: ignore[attr-defined] + outputs[i]) for i, worker in enumerate(tp_group) + ] + else: + outputs = [ + worker.execute_model_spmd. + bind( # type: ignore[attr-defined] + outputs[i]) for i, worker in enumerate(tp_group) + ] + + last_pp_rank = len(self.pp_tp_workers) - 1 + if (pp_rank < last_pp_rank and + envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE != "shm"): + # Specify how intermediate tensors should be passed + # between pp stages, no need to specify for the last + # pp stage or when using shared memory (the default). + transport = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE + outputs = [ + output.with_tensor_transport(transport=transport) + for output in outputs + ] + + forward_dag = MultiOutputNode(outputs) + + return forward_dag.experimental_compile( + enable_asyncio=enable_asyncio, + _overlap_gpu_communication=envs. + VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM) + + def __del__(self): + self.shutdown() + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if not self.use_ray_spmd_worker: + return await super().execute_model_async(execute_model_req) + + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag(enable_asyncio=True) + + serialized_data = self.input_encoder.encode(execute_model_req) + dag_future = await self.forward_dag.execute_async(serialized_data) + output = await dag_future[0] + return self.output_decoder.decode(output) + + async def _driver_execute_model_async( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + assert not self.use_ray_spmd_worker, ( + "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1") + if not self.tp_driver_workers: + return await self.driver_exec_method("execute_model", + execute_model_req) + if self.pp_locks is None: + # This locks each pipeline parallel stage so multiple virtual + # engines can't execute on the same stage at the same time + # We create the locks here to avoid creating them in the constructor + # which uses a different asyncio loop. + self.pp_locks = [ + asyncio.Lock() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + tasks = [ + asyncio.create_task( + _run_task_with_lock(self.driver_exec_method, self.pp_locks[0], + "execute_model", execute_model_req)) + ] + for pp_rank, driver_worker in enumerate(self.tp_driver_workers, + start=1): + tasks.append( + asyncio.create_task( + _run_task_with_lock(driver_worker.execute_method.remote, + self.pp_locks[pp_rank], + "execute_model", execute_model_req))) + + results = await asyncio.gather(*tasks) + + # Only the last PP stage has the final results. + return results[-1] + + async def _start_worker_execution_loop(self): + assert not self.use_ray_spmd_worker, ( + "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1") + coros = [ + worker.execute_method.remote("start_worker_execution_loop") + for worker in self.non_driver_workers + ] + return await asyncio.gather(*coros) + + def check_health(self) -> None: + # Assume that the Ray workers are healthy. + # TODO: check the health of the Ray workers + return diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py new file mode 100644 index 0000000..c222f16 --- /dev/null +++ b/vllm/executor/ray_utils.py @@ -0,0 +1,399 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import time +from collections import defaultdict +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import msgspec + +import vllm.platforms +from vllm.config import ParallelConfig +from vllm.executor.msgspec_utils import decode_hook, encode_hook +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.sequence import ExecuteModelRequest, IntermediateTensors +from vllm.utils import get_ip +from vllm.worker.worker_base import WorkerWrapperBase + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.outputs import ModelRunnerOutput + +logger = init_logger(__name__) +PG_WAIT_TIMEOUT = 1800 + +try: + import ray + from ray.util import placement_group_table + from ray.util.placement_group import PlacementGroup + try: + from ray._private.state import available_resources_per_node + except ImportError: + # Ray 2.9.x doesn't expose `available_resources_per_node` + from ray._private.state import state as _state + available_resources_per_node = _state._available_resources_per_node + + class RayWorkerWrapper(WorkerWrapperBase): + """Ray wrapper for vllm.worker.Worker, allowing Worker to be + lazily initialized after Ray sets CUDA_VISIBLE_DEVICES.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + # Since the compiled DAG runs a main execution + # in a different thread that calls cuda.set_device. + # The flag indicates is set_device is called on + # that thread. + self.compiled_dag_cuda_device_set = False + + self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest, + dec_hook=decode_hook) + self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) + + def get_node_ip(self) -> str: + return get_ip() + + def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: + node_id = ray.get_runtime_context().get_node_id() + device_key = vllm.platforms.current_platform.ray_device_key + if not device_key: + raise RuntimeError("current platform %s does not support ray.", + vllm.platforms.current_platform.device_name) + gpu_ids = ray.get_runtime_context().get_accelerator_ids( + )[device_key] + return node_id, gpu_ids + + def execute_model_spmd( + self, req_or_tuple: Union[bytes, + Tuple[bytes, + Optional[IntermediateTensors]]] + ) -> bytes: + """Execute model in SPMD fashion: used only when SPMD worker and + compiled DAG are both enabled. + + Args: + req_or_tuple: A request or a tuple containing the + request and intermediate tensors. Intermediate tensors are + None unless if it is provided because it is > 0 pipeline + stage. The request is serialized by msgspec. + """ + if isinstance(req_or_tuple, bytes): + serialized_req, intermediate_tensors = req_or_tuple, None + else: + serialized_req, intermediate_tensors = req_or_tuple + + execute_model_req = self.input_decoder.decode(serialized_req) + + # TODO(swang): This is needed right now because Ray Compiled Graph + # executes on a background thread, so we need to reset torch's + # current device. + if not self.compiled_dag_cuda_device_set: + current_platform.set_device(self.worker.device) + self.compiled_dag_cuda_device_set = True + + output = self.worker._execute_model_spmd(execute_model_req, + intermediate_tensors) + # Pipeline model request and output to the next pipeline stage. + if isinstance(output, IntermediateTensors): + output = serialized_req, output + else: + output = self.output_encoder.encode(output) + + return output + + def setup_device_if_necessary(self): + # TODO(swang): This is needed right now because Ray CG executes + # on a background thread, so we need to reset torch's current + # device. + # We can remove this API after it is fixed in compiled graph. + assert self.worker is not None, "Worker is not initialized" + if not self.compiled_dag_cuda_device_set: + if current_platform.is_tpu(): + # Not needed + pass + else: + current_platform.set_device(self.worker.device) + + self.compiled_dag_cuda_device_set = True + + def execute_model_ray( + self, + scheduler_output: Union["SchedulerOutput", + Tuple["SchedulerOutput", + "IntermediateTensors"]], + ) -> Union["ModelRunnerOutput", Tuple["SchedulerOutput", + "IntermediateTensors"]]: + # This method is used by Ray Compiled Graph to execute the model, + # and it needs a special logic of self.setup_device_if_necessary() + self.setup_device_if_necessary() + assert self.worker is not None, "Worker is not initialized" + if isinstance(scheduler_output, tuple): + scheduler_output, intermediate_tensors = scheduler_output + else: + scheduler_output, intermediate_tensors = scheduler_output, None + output = self.worker.model_runner.execute_model( + scheduler_output, intermediate_tensors) + if isinstance(output, IntermediateTensors): + output = scheduler_output, output + return output + + def override_env_vars(self, vars: Dict[str, str]): + os.environ.update(vars) + + ray_import_err = None + +except ImportError as e: + ray = None # type: ignore + ray_import_err = e + RayWorkerWrapper = None # type: ignore + + +def ray_is_available() -> bool: + """Returns True if Ray is available.""" + return ray is not None + + +def assert_ray_available(): + """Raise an exception if Ray is not available.""" + if ray is None: + raise ValueError("Failed to import Ray, please install Ray with " + "`pip install ray`.") from ray_import_err + + +def _verify_bundles(placement_group: "PlacementGroup", + parallel_config: ParallelConfig, device_str: str): + """Verify a given placement group has bundles located in the right place. + + There are 2 rules. + - Warn if all tensor parallel workers cannot fit in a single node. + - Fail if driver node is not included in a placement group. + """ + assert ray.is_initialized(), ( + "Ray is not initialized although distributed-executor-backend is ray.") + pg_data = placement_group_table(placement_group) + # bundle_idx -> node_id + bundle_to_node_ids = pg_data["bundles_to_node_id"] + # bundle_idx -> bundle (e.g., {"GPU": 1}) + bundles = pg_data["bundles"] + # node_id -> List of bundle (e.g., {"GPU": 1}) + node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list) + + for bundle_idx, node_id in bundle_to_node_ids.items(): + node_id_to_bundle[node_id].append(bundles[bundle_idx]) + driver_node_id = ray.get_runtime_context().get_node_id() + + if driver_node_id not in node_id_to_bundle: + raise RuntimeError( + f"driver node id {driver_node_id} is not included in a placement " + f"group {placement_group.id}. Node id -> bundles " + f"{node_id_to_bundle}. " + "You don't have enough GPUs available in a current node. Check " + "`ray status` and `ray list nodes` to see if you have available " + "GPUs in a node `{driver_node_id}` before starting an vLLM engine." + ) + + for node_id, bundles in node_id_to_bundle.items(): + if len(bundles) < parallel_config.tensor_parallel_size: + logger.warning( + "tensor_parallel_size=%d " + "is bigger than a reserved number of %ss (%d " + "%ss) in a node %s. Tensor parallel workers can be " + "spread out to 2+ nodes which can degrade the performance " + "unless you have fast interconnect across nodes, like " + "Infiniband. To resolve this issue, make sure you have more " + "than %d GPUs available at each node.", + parallel_config.tensor_parallel_size, device_str, len(bundles), + device_str, node_id, parallel_config.tensor_parallel_size) + + +def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): + """Wait until a placement group is ready. + + It prints the informative log messages if the placement group is + not created within time. + + """ + # Wait until PG is ready - this will block until all + # requested resources are available, and will timeout + # if they cannot be provisioned. + placement_group_specs = current_placement_group.bundle_specs + + s = time.time() + pg_ready_ref = current_placement_group.ready() + wait_interval = 10 + while time.time() - s < PG_WAIT_TIMEOUT: + ready, _ = ray.wait([pg_ready_ref], timeout=wait_interval) + if len(ready) > 0: + break + + # Exponential backoff for warning print. + wait_interval *= 2 + logger.info( + "Waiting for creating a placement group of specs for " + "%d seconds. specs=%s. Check `ray status` and " + "`ray list nodes` to see if you have enough resources," + " and make sure the IP addresses used by ray cluster" + " are the same as VLLM_HOST_IP environment variable" + " specified in each node if you are running on a multi-node.", + int(time.time() - s), placement_group_specs) + + try: + ray.get(pg_ready_ref, timeout=0) + except ray.exceptions.GetTimeoutError: + raise ValueError( + "Cannot provide a placement group of " + f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See " + "`ray status` and `ray list nodes` to make sure the cluster has " + "enough resources.") from None + + +def _wait_until_pg_removed(current_placement_group: "PlacementGroup"): + ray.util.remove_placement_group(current_placement_group) + s = time.time() + wait_interval = 10 + while time.time() - s < PG_WAIT_TIMEOUT: + pg = ray.util.get_current_placement_group() + if pg is None: + break + + # Exponential backoff for warning print. + wait_interval *= 2 + logger.info( + "Waiting for removing a placement group of specs for " + "%d seconds.", int(time.time() - s)) + time.sleep(wait_interval) + + +def initialize_ray_cluster( + parallel_config: ParallelConfig, + ray_address: Optional[str] = None, +): + """Initialize the distributed cluster with Ray. + + it will connect to the Ray cluster and create a placement group + for the workers, which includes the specification of the resources + for each distributed worker. + + Args: + parallel_config: The configurations for parallel execution. + ray_address: The address of the Ray cluster. If None, uses + the default Ray cluster address. + """ + assert_ray_available() + from vllm.platforms import current_platform + + if ray.is_initialized(): + logger.info("Ray is already initialized. Skipping Ray initialization.") + elif current_platform.is_rocm() or current_platform.is_xpu(): + # Try to connect existing ray instance and create a new one if not found + try: + ray.init("auto") + except ConnectionError: + logger.warning( + "No existing RAY instance detected. " + "A new instance will be launched with current node resources.") + ray.init(address=ray_address, num_gpus=parallel_config.world_size) + else: + ray.init(address=ray_address) + + device_str = current_platform.ray_device_key + if not device_str: + raise ValueError( + f"current platform {current_platform.device_name} does not " + "support ray.") + + # Create or get the placement group for worker processes + if parallel_config.placement_group: + current_placement_group = parallel_config.placement_group + else: + current_placement_group = ray.util.get_current_placement_group() + + if current_placement_group: + logger.info("Using the existing placement group") + + # We are in a placement group + bundles = current_placement_group.bundle_specs + # Verify that we can use the placement group. + device_bundles = 0 + for bundle in bundles: + bundle_devices = bundle.get(device_str, 0) + if bundle_devices > 1: + raise ValueError( + "Placement group bundle cannot have more than 1 " + f"{device_str}.") + if bundle_devices: + device_bundles += 1 + if parallel_config.world_size > device_bundles: + raise ValueError( + f"The number of required {device_str}s exceeds the total " + f"number of available {device_str}s in the placement group. " + f"Required number of devices: {parallel_config.world_size}. " + f"Total number of devices: {device_bundles}.") + else: + logger.info("No current placement group found. " + "Creating a new placement group.") + num_devices_in_cluster = ray.cluster_resources().get(device_str, 0) + # Log a warning message and delay resource allocation failure response. + # Avoid immediate rejection to allow user-initiated placement group + # created and wait cluster to be ready + if parallel_config.world_size > num_devices_in_cluster: + logger.warning( + "The number of required %ss exceeds the total " + "number of available %ss in the placement group.", device_str, + device_str) + # Create a new placement group + placement_group_specs: List[Dict[str, float]] = ([{ + device_str: 1.0 + } for _ in range(parallel_config.world_size)]) + + # vLLM engine is also a worker to execute model with an accelerator, + # so it requires to have the device in a current node. Check if + # the current node has at least one device. + current_ip = get_ip() + current_node_id = ray.get_runtime_context().get_node_id() + current_node_resource = available_resources_per_node()[current_node_id] + if current_node_resource.get(device_str, 0) < 1: + raise ValueError( + f"Current node has no {device_str} available. " + f"{current_node_resource=}. vLLM engine cannot start without " + f"{device_str}. Make sure you have at least 1 {device_str} " + f"available in a node {current_node_id=} {current_ip=}.") + # This way, at least bundle is required to be created in a current + # node. + placement_group_specs[0][f"node:{current_ip}"] = 0.001 + + # By default, Ray packs resources as much as possible. + current_placement_group = ray.util.placement_group( + placement_group_specs, strategy="PACK") + _wait_until_pg_ready(current_placement_group) + + assert current_placement_group is not None + _verify_bundles(current_placement_group, parallel_config, device_str) + # Set the placement group in the parallel config + parallel_config.placement_group = current_placement_group + + +def get_num_tpu_nodes() -> int: + from ray._private.accelerators import TPUAcceleratorManager + cluster_resources = ray.cluster_resources() + total_tpus = int(cluster_resources["TPU"]) + tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators() + assert total_tpus % tpus_per_node == 0 + return total_tpus // tpus_per_node + + +def get_num_nodes_in_placement_group() -> int: + pg_table = ray.util.placement_group_table() + current_pg = ray.util.get_current_placement_group() + num_nodes = 0 + + if current_pg: + nodes_in_pg = set() + for pg_key, pg in pg_table.items(): + if pg_key == current_pg.id.hex(): + for _, node in pg["bundles_to_node_id"].items(): + nodes_in_pg.add(node) + num_nodes = len(nodes_in_pg) + + return num_nodes diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py new file mode 100644 index 0000000..7ebeb4a --- /dev/null +++ b/vllm/executor/uniproc_executor.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist + +import vllm.envs as envs +from vllm.executor.executor_base import ExecutorBase +from vllm.logger import init_logger +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + run_method) +from vllm.worker.worker_base import WorkerWrapperBase + +logger = init_logger(__name__) + + +class UniProcExecutor(ExecutorBase): + + uses_ray: bool = False + + def _init_executor(self) -> None: + """Initialize the worker and load the model. + """ + self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, + rpc_rank=0) + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + local_rank = 0 + # set local rank as the device index if specified + device_info = self.vllm_config.device_config.device.__str__().split( + ":") + if len(device_info) > 1: + local_rank = int(device_info[1]) + rank = 0 + is_driver_worker = True + kwargs = dict( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=is_driver_worker, + ) + self.collective_rpc("init_worker", args=([kwargs], )) + self.collective_rpc("init_device") + self.collective_rpc("load_model") + + def collective_rpc(self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: Tuple = (), + kwargs: Optional[Dict] = None) -> List[Any]: + if kwargs is None: + kwargs = {} + answer = run_method(self.driver_worker, method, args, kwargs) + return [answer] + + def check_health(self) -> None: + # UniProcExecutor will always be healthy as long as + # it's running. + return + + +UniProcExecutorAsync = UniProcExecutor + + +class ExecutorWithExternalLauncher(UniProcExecutor): + """An executor that uses external launchers to launch engines, + specially designed for torchrun-compatible launchers, for + offline inference with tensor parallelism. + + see https://github.com/vllm-project/vllm/issues/11400 for + the motivation, and examples/offline_inference/torchrun_example.py + for the usage example. + + The key idea: although it is tensor-parallel inference, we only + create one worker per executor, users will launch multiple + engines with torchrun-compatible launchers, and all these engines + work together to process the same prompts. When scheduling is + deterministic, all the engines will generate the same outputs, + and they don't need to synchronize the states with each other. + """ + uses_ray: bool = False + + def _init_executor(self) -> None: + """Initialize the worker and load the model. + """ + assert self.vllm_config.scheduler_config.delay_factor == 0.0, \ + ("ExecutorWithExternalLauncher needs deterministic " + "execution, so it" + "does not support delay_factor in scheduling") + if envs.VLLM_USE_V1: + assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \ + ("To get deterministic execution in V1, " + "please set VLLM_ENABLE_V1_MULTIPROCESSING=0") + self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, + rpc_rank=0) + # engines are launched in torchrun-compatible launchers + # so we can use the env:// method. + # required env vars: + # - RANK + # - LOCAL_RANK + # - MASTER_ADDR + # - MASTER_PORT + distributed_init_method = "env://" + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + is_driver_worker = True + kwargs = dict( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=is_driver_worker, + ) + self.collective_rpc("init_worker", args=([kwargs], )) + self.collective_rpc("init_device") + self.collective_rpc("load_model") + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """ + Determine the number of available KV blocks. + Add an additional all_reduce to get the min across all ranks. + Note that even if we have the same `gpu_memory_utilization` and + `swap_space`, the available memory in every rank might still + differ because NCCL can take different amounts of memory in + different ranks. Therefore, it is necessary to test if all ranks + agree on the same KV cache configuration. + """ + a, b = super().determine_num_available_blocks() + from vllm.distributed.parallel_state import get_world_group + cpu_group = get_world_group().cpu_group + a_tensor = torch.tensor([a], device="cpu", dtype=torch.int64) + b_tensor = torch.tensor([b], device="cpu", dtype=torch.int64) + dist.all_reduce(a_tensor, group=cpu_group, op=dist.ReduceOp.MIN) + dist.all_reduce(b_tensor, group=cpu_group, op=dist.ReduceOp.MIN) + return a_tensor.item(), b_tensor.item() diff --git a/vllm/forward_context.py b/vllm/forward_context.py new file mode 100644 index 0000000..d899339 --- /dev/null +++ b/vllm/forward_context.py @@ -0,0 +1,211 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import time +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, Union + +import torch +import torch.distributed as dist + +import vllm.envs as envs +from vllm.config import ParallelConfig, VllmConfig +from vllm.logger import init_logger +from vllm.two_batch_overlap.forward_context import get_tbo_forward_context, set_tbo_forward_context + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + +logger = init_logger(__name__) + +track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0 +last_logging_time: float = 0 +forward_start_time: float = 0 +batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL +batchsize_forward_time: defaultdict = defaultdict(list) + +@dataclass +class DPMetadata: + max_tokens_across_dp_cpu: torch.Tensor + cu_tokens_across_dp_cpu: torch.Tensor + + @staticmethod + def num_tokens_across_dp(num_tokens: int, dp_size: int, + dp_rank: int) -> torch.Tensor: + """ + Gather the num_tokens across all DP ranks and return results in a + CPU tensor of size dp_size. + """ + num_tokens_across_dp = [0] * dp_size + num_tokens_across_dp[dp_rank] = num_tokens + num_tokens_tensor = torch.tensor(num_tokens_across_dp, + device="cpu", + dtype=torch.int32) + from vllm.distributed.parallel_state import get_dp_group + dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) + return num_tokens_tensor + + @staticmethod + def make( + parallel_config: ParallelConfig, + attn_metadata: Any, + num_tokens: int, + num_tokens_across_dp: Optional[torch.Tensor] = None + ) -> "DPMetadata": + + assert parallel_config.data_parallel_size > 1 + dp_size = parallel_config.data_parallel_size + dp_rank = parallel_config.data_parallel_rank + if attn_metadata is not None and hasattr(attn_metadata, + "num_prefill_tokens"): + # for v0 attention backends + batchsize = attn_metadata.num_prefill_tokens + \ + attn_metadata.num_decode_tokens + else: + # for v1 attention backends or no attn_metadata + batchsize = num_tokens + + # If num_tokens_across_dp is None, it will be computed by all_reduce + # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize + assert (num_tokens_across_dp is None + or num_tokens_across_dp[dp_rank] == batchsize) + if num_tokens_across_dp is None: + num_tokens_across_dp = DPMetadata.num_tokens_across_dp( + batchsize, dp_size, dp_rank) + max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp) + cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0) + return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu) + + +@dataclass +class ForwardContext: + # copy from vllm_config.compilation_config.static_forward_context + no_compile_layers: dict[str, Any] + """ + Type AttentionMetadata for v0, + Type Dict[str, AttentionMetadata] for v1, map from layer_name of each + attention layer to its attention metadata + set dynamically for each forward pass + """ + attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]] + # TODO: remove after making all virtual_engines share the same kv cache + virtual_engine: int # set dynamically for each forward pass + # set dynamically for each forward pass + dp_metadata: Optional[DPMetadata] = None + skip_cuda_graphs: bool = False + + +_forward_context: Optional[ForwardContext] = None + + +def get_forward_context() -> ForwardContext: + if envs.VLLM_ENABLE_TBO: + forward_context = get_tbo_forward_context() + """Get the current forward context.""" + assert forward_context is not None, ( + "Forward context is not set. " + "Please use `set_forward_context` to set the forward context.") + return forward_context + + """Get the current forward context.""" + assert _forward_context is not None, ( + "Forward context is not set. " + "Please use `set_forward_context` to set the forward context.") + return _forward_context + + +@contextmanager +def set_forward_context( + attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: Optional[int] = None, + num_tokens_across_dp: Optional[torch.Tensor] = None, + skip_cuda_graphs: bool = False, +): + """A context manager that stores the current forward context, + can be attention metadata, etc. + Here we can inject common logic for every model forward pass. + """ + global forward_start_time + need_to_track_batchsize = track_batchsize and attn_metadata is not None + if need_to_track_batchsize: + forward_start_time = time.perf_counter() + dp_metadata: Optional[DPMetadata] = None + if vllm_config.parallel_config.data_parallel_size > 1 and ( + attn_metadata is not None or num_tokens is not None): + dp_metadata = DPMetadata.make(vllm_config.parallel_config, + attn_metadata, num_tokens or 0, + num_tokens_across_dp) + + global _forward_context + prev_context = _forward_context + _forward_context = ForwardContext( + no_compile_layers=vllm_config.compilation_config. + static_forward_context, + virtual_engine=virtual_engine, + attn_metadata=attn_metadata, + dp_metadata=dp_metadata, + skip_cuda_graphs=skip_cuda_graphs, + ) + if envs.VLLM_ENABLE_TBO: + set_tbo_forward_context(_forward_context) + + try: + yield + finally: + global last_logging_time, batchsize_logging_interval + if need_to_track_batchsize: + if hasattr(attn_metadata, "num_prefill_tokens"): + # for v0 attention backends + batchsize = attn_metadata.num_prefill_tokens + \ + attn_metadata.num_decode_tokens + else: + # for v1 attention backends + batchsize = num_tokens + # we use synchronous scheduling right now, + # adding a sync point here should not affect + # scheduling of the next batch + from vllm.platforms import current_platform + synchronize = current_platform.synchronize + if synchronize is not None: + synchronize() + now = time.perf_counter() + # time measurement is in milliseconds + batchsize_forward_time[batchsize].append( + (now - forward_start_time) * 1000) + if now - last_logging_time > batchsize_logging_interval: + last_logging_time = now + forward_stats = [] + for bs, times in batchsize_forward_time.items(): + if len(times) <= 1: + # can be cudagraph / profiling run + continue + medium = torch.quantile(torch.tensor(times), q=0.5).item() + medium = round(medium, 2) + forward_stats.append((bs, len(times), medium)) + forward_stats.sort(key=lambda x: x[1], reverse=True) + if forward_stats: + logger.info(("Batchsize forward time stats " + "(batchsize, count, median_time(ms)): %s"), + forward_stats) + + _forward_context = prev_context + if envs.VLLM_ENABLE_TBO: + set_tbo_forward_context(_forward_context) + + +_profiling: bool = False + +@contextmanager +def set_profilling(profiling): + global _profiling + _profiling = profiling + + +def get_profilling() -> bool: + global _profiling + return _profiling \ No newline at end of file diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py new file mode 100644 index 0000000..37bf2b7 --- /dev/null +++ b/vllm/inputs/__init__.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs, + ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, + SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, + TokensPrompt, build_explicit_enc_dec_prompt, embeds_inputs, + to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) +from .registry import (DummyData, InputContext, InputProcessingContext, + InputRegistry) + +INPUT_REGISTRY = InputRegistry() +""" +The global [`InputRegistry`][vllm.inputs.registry.InputRegistry] which is used +by [`LLMEngine`][vllm.LLMEngine] to dispatch data processing according to the +target model. +""" + +__all__ = [ + "TextPrompt", + "TokensPrompt", + "PromptType", + "SingletonPrompt", + "ExplicitEncoderDecoderPrompt", + "TokenInputs", + "EmbedsInputs", + "token_inputs", + "embeds_inputs", + "DecoderOnlyInputs", + "EncoderDecoderInputs", + "ProcessorInputs", + "SingletonInputs", + "build_explicit_enc_dec_prompt", + "to_enc_dec_tuple_list", + "zip_enc_dec_prompts", + "INPUT_REGISTRY", + "DummyData", + "InputContext", + "InputProcessingContext", + "InputRegistry", +] diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py new file mode 100644 index 0000000..d143999 --- /dev/null +++ b/vllm/inputs/data.py @@ -0,0 +1,331 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast + +import torch +from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar + +if TYPE_CHECKING: + from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs + + +class TextPrompt(TypedDict): + """Schema for a text prompt.""" + + prompt: str + """The input text to be tokenized before passing to the model.""" + + multi_modal_data: NotRequired["MultiModalDataDict"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + mm_processor_kwargs: NotRequired[dict[str, Any]] + """ + Optional multi-modal processor kwargs to be forwarded to the + multimodal input mapper & processor. Note that if multiple modalities + have registered mappers etc for the model being considered, we attempt + to pass the mm_processor_kwargs to each of them. + """ + + cache_salt: NotRequired[str] + """ + Optional cache salt to be used for prefix caching. + """ + + +class TokensPrompt(TypedDict): + """Schema for a tokenized prompt.""" + + prompt_token_ids: list[int] + """A list of token IDs to pass to the model.""" + + token_type_ids: NotRequired[list[int]] + """A list of token type IDs to pass to the cross encoder model.""" + + multi_modal_data: NotRequired["MultiModalDataDict"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + mm_processor_kwargs: NotRequired[dict[str, Any]] + """ + Optional multi-modal processor kwargs to be forwarded to the + multimodal input mapper & processor. Note that if multiple modalities + have registered mappers etc for the model being considered, we attempt + to pass the mm_processor_kwargs to each of them. + """ + + cache_salt: NotRequired[str] + """ + Optional cache salt to be used for prefix caching. + """ + + +class EmbedsPrompt(TypedDict): + """Schema for a prompt provided via token embeddings.""" + + prompt_embeds: torch.Tensor + """The embeddings of the prompt.""" + + cache_salt: NotRequired[str] + """ + Optional cache salt to be used for prefix caching. + """ + + +SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt] +""" +Set of possible schemas for a single prompt: + +- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt]) +- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt]) +- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt]) + +Note that "singleton" is as opposed to a data structure +which encapsulates multiple prompts, i.e. of the sort +which may be utilized for encoder/decoder models when +the user desires to express both the encoder & decoder +prompts explicitly, i.e. +[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt] + +A prompt of type [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] may be +employed as (1) input to a decoder-only model, (2) input to +the encoder of an encoder/decoder model, in the scenario +where the decoder-prompt is not specified explicitly, or +(3) as a member of a larger data structure encapsulating +more than one prompt, i.e. +[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt] +""" + + +def is_tokens_prompt(prompt: SingletonPrompt) -> TypeIs[TokensPrompt]: + return (isinstance(prompt, dict) and "prompt_token_ids" in prompt + and "prompt_embeds" not in prompt) + + +def is_embeds_prompt(prompt: SingletonPrompt) -> TypeIs[EmbedsPrompt]: + return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt + and "prompt_embeds" in prompt) + + +_T1_co = TypeVar("_T1_co", + bound=SingletonPrompt, + default=SingletonPrompt, + covariant=True) +_T2_co = TypeVar("_T2_co", + bound=SingletonPrompt, + default=SingletonPrompt, + covariant=True) + + +# TODO: Make fields ReadOnly once mypy supports it +class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): + """ + Represents an encoder/decoder model input prompt, + comprising an explicit encoder prompt and a decoder prompt. + + The encoder and decoder prompts, respectively, may be formatted + according to any of the + [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] schemas, + and are not required to have the same schema. + + Only the encoder prompt may have multi-modal data. mm_processor_kwargs + should be at the top-level, and should not be set in the encoder/decoder + prompts, since they are agnostic to the encoder/decoder. + + Note that an + [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt] + may not be used as an input to a decoder-only model, + and that the `encoder_prompt` and `decoder_prompt` + fields of this data structure themselves must be + [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] instances. + """ + + encoder_prompt: _T1_co + + decoder_prompt: Optional[_T2_co] + + mm_processor_kwargs: NotRequired[dict[str, Any]] + + +PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt] +""" +Set of possible schemas for an LLM input, including +both decoder-only and encoder/decoder input types: + +- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt]) +- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt]) +- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt]) +- A single data structure containing both an encoder and a decoder prompt + ([`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]) +""" + + +class TokenInputs(TypedDict): + """Represents token-based inputs.""" + + type: Literal["token"] + """The type of inputs.""" + + prompt_token_ids: list[int] + """The token IDs of the prompt.""" + + token_type_ids: NotRequired[list[int]] + """The token type IDs of the prompt.""" + + prompt: NotRequired[str] + """ + The original prompt text corresponding to the token IDs, if available. + """ + + cache_salt: NotRequired[str] + """ + Optional cache salt to be used for prefix caching. + """ + + +def token_inputs( + prompt_token_ids: list[int], + token_type_ids: Optional[list[int]] = None, + prompt: Optional[str] = None, + cache_salt: Optional[str] = None, +) -> TokenInputs: + """Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional + values.""" + inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) + + if prompt is not None: + inputs["prompt"] = prompt + if token_type_ids is not None: + inputs["token_type_ids"] = token_type_ids + if cache_salt is not None: + inputs["cache_salt"] = cache_salt + + return inputs + + +class EmbedsInputs(TypedDict): + """Represents embeddings-based inputs.""" + + type: Literal["embeds"] + """The type of inputs.""" + + prompt_embeds: torch.Tensor + """The embeddings of the prompt.""" + + cache_salt: NotRequired[str] + """ + Optional cache salt to be used for prefix caching. + """ + + +def embeds_inputs( + prompt_embeds: torch.Tensor, + cache_salt: Optional[str] = None, +) -> EmbedsInputs: + """Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional + values.""" + inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds) + + if cache_salt is not None: + inputs["cache_salt"] = cache_salt + + return inputs + + +DecoderOnlyInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"] +""" +The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they are +passed to the model executor. +This specifies the data required for decoder-only models. +""" + + +class EncoderDecoderInputs(TypedDict): + """ + The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they + are passed to the model executor. + + This specifies the required data for encoder-decoder models. + """ + + encoder: Union[TokenInputs, "MultiModalInputs"] + """The inputs for the encoder portion.""" + + decoder: Union[TokenInputs, "MultiModalInputs"] + """The inputs for the decoder portion.""" + + +SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"] +""" +A processed [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] which can be +passed to [`vllm.sequence.Sequence`][]. +""" + +ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs] +""" +The outputs from [`vllm.inputs.preprocess.InputPreprocessor`][]. +""" + +_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt) +_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt) + + +def build_explicit_enc_dec_prompt( + encoder_prompt: _T1, + decoder_prompt: Optional[_T2], + mm_processor_kwargs: Optional[dict[str, Any]] = None, +) -> ExplicitEncoderDecoderPrompt[_T1, _T2]: + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + return ExplicitEncoderDecoderPrompt( + encoder_prompt=encoder_prompt, + decoder_prompt=decoder_prompt, + mm_processor_kwargs=mm_processor_kwargs, + ) + + +def zip_enc_dec_prompts( + enc_prompts: Iterable[_T1], + dec_prompts: Iterable[Optional[_T2]], + mm_processor_kwargs: Optional[Union[Iterable[dict[str, Any]], + dict[str, Any]]] = None, +) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]: + """ + Zip encoder and decoder prompts together into a list of + [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt] + instances. + + ``mm_processor_kwargs`` may also be provided; if a dict is passed, the same + dictionary will be used for every encoder/decoder prompt. If an iterable is + provided, it will be zipped with the encoder/decoder prompts. + """ + if mm_processor_kwargs is None: + mm_processor_kwargs = cast(dict[str, Any], {}) + if isinstance(mm_processor_kwargs, dict): + return [ + build_explicit_enc_dec_prompt( + encoder_prompt, + decoder_prompt, + cast(dict[str, Any], mm_processor_kwargs), + ) for (encoder_prompt, + decoder_prompt) in zip(enc_prompts, dec_prompts) + ] + return [ + build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, + mm_proc_kwargs) + for (encoder_prompt, decoder_prompt, mm_proc_kwargs + ) in zip(enc_prompts, dec_prompts, mm_processor_kwargs) + ] + + +def to_enc_dec_tuple_list( + enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]], +) -> list[tuple[_T1, Optional[_T2]]]: + return [(enc_dec_prompt["encoder_prompt"], + enc_dec_prompt["decoder_prompt"]) + for enc_dec_prompt in enc_dec_prompts] \ No newline at end of file diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py new file mode 100644 index 0000000..8c37007 --- /dev/null +++ b/vllm/inputs/parse.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Sequence +from typing import Literal, Optional, TypedDict, Union, cast, overload + +from typing_extensions import TypeIs + +from vllm.utils import is_list_of + +from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs, + PromptType, SingletonInputs, SingletonPrompt, TextPrompt, + TokensPrompt) + + +class ParsedText(TypedDict): + content: str + is_tokens: Literal[False] + + +class ParsedTokens(TypedDict): + content: list[int] + is_tokens: Literal[True] + + +@overload +def parse_and_batch_prompt( + prompt: Union[str, list[str]], ) -> Sequence[ParsedText]: + ... + + +@overload +def parse_and_batch_prompt( + prompt: Union[list[int], list[list[int]]], ) -> Sequence[ParsedTokens]: + ... + + +def parse_and_batch_prompt( + prompt: Union[str, list[str], list[int], list[list[int]]], +) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]: + if isinstance(prompt, str): + # case 1: a string + return [ParsedText(content=prompt, is_tokens=False)] + + if isinstance(prompt, list): + if len(prompt) == 0: + raise ValueError("please provide at least one prompt") + + if is_list_of(prompt, str): + # case 2: array of strings + prompt = cast(list[str], prompt) + return [ + ParsedText(content=elem, is_tokens=False) for elem in prompt + ] + if is_list_of(prompt, int): + # case 3: array of tokens + prompt = cast(list[int], prompt) + return [ParsedTokens(content=prompt, is_tokens=True)] + if is_list_of(prompt, list): + prompt = cast(list[list[int]], prompt) + if len(prompt[0]) == 0: + raise ValueError("please provide at least one prompt") + + if is_list_of(prompt[0], int): + # case 4: array of token arrays + return [ + ParsedTokens(content=elem, is_tokens=True) + for elem in prompt + ] + + raise TypeError("prompt must be a string, array of strings, " + "array of tokens, or array of token arrays") + + +class ParsedStrPrompt(TypedDict): + type: Literal["str"] + content: str + + +class ParsedTextPrompt(TypedDict): + type: Literal["text"] + content: TextPrompt + + +class ParsedTokensPrompt(TypedDict): + type: Literal["tokens"] + content: TokensPrompt + + +class ParsedEmbedsPrompt(TypedDict): + type: Literal["embeds"] + content: EmbedsPrompt + + +ParsedSingletonPrompt = Union[ParsedStrPrompt, ParsedTextPrompt, + ParsedTokensPrompt, ParsedEmbedsPrompt] + + +@overload +def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt: + ... + + +@overload +def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt: + ... + + +@overload +def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt: + ... + + +@overload +def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt: + ... + + +def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt: + if isinstance(prompt, str): + return ParsedStrPrompt(type="str", content=prompt) + elif isinstance(prompt, dict): + # Type ignores are because mypy does not correctly infer the TypedDicts + # Pyright does succeed. + if "prompt_embeds" in prompt: + return ParsedEmbedsPrompt( + type="embeds", content=prompt) # type: ignore[typeddict-item] + elif "prompt_token_ids" in prompt: + return ParsedTokensPrompt( + type="tokens", content=prompt) # type: ignore[typeddict-item] + elif "prompt" in prompt: + return ParsedTextPrompt(type="text", content=prompt) + raise TypeError( + "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt") + + +def is_explicit_encoder_decoder_prompt( + prompt: PromptType, ) -> TypeIs[ExplicitEncoderDecoderPrompt]: + return isinstance(prompt, dict) and "encoder_prompt" in prompt + + +def split_enc_dec_inputs( + inputs: ProcessorInputs, +) -> tuple[Optional[SingletonInputs], SingletonInputs]: + if "encoder" in inputs and "decoder" in inputs: + # NOTE: This passes pyright but not mypy + return ( + inputs["encoder"], # type: ignore[typeddict-item] + inputs["decoder"], # type: ignore[typeddict-item] + ) + + return None, inputs diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py new file mode 100644 index 0000000..e24959b --- /dev/null +++ b/vllm/inputs/preprocess.py @@ -0,0 +1,927 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +from collections.abc import Mapping +from typing import Any, Optional, Union, cast + +from typing_extensions import assert_never + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, + MultiModalInputs) +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizer_group import TokenizerGroup + +from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, + EncoderDecoderInputs, ProcessorInputs, PromptType, + SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, + TokensPrompt, embeds_inputs, token_inputs) +from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt + +logger = init_logger(__name__) + + +class InputPreprocessor: + + def __init__( + self, + model_config: ModelConfig, + tokenizer: Optional[TokenizerGroup], + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + ) -> None: + super().__init__() + + self.model_config = model_config + self.tokenizer = tokenizer + self.mm_registry = mm_registry + + def get_tokenizer_group(self) -> TokenizerGroup: + if self.tokenizer is None: + raise ValueError("You cannot pass text prompts when " + "`skip_tokenizer_init` is True") + + return self.tokenizer + + def get_bos_token_id(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + if self.tokenizer is None: + logger.warning("Using None for BOS token id because tokenizer " + "is not initialized") + return None + + return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id + + def get_eos_token_id(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + if self.tokenizer is None: + logger.warning("Using None for EOS token id because tokenizer " + "is not initialized") + return None + + return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id + + def get_decoder_start_token_id(self) -> Optional[int]: + """ + Obtain the decoder start token id employed by an encoder/decoder + model. Returns None for non-encoder/decoder models or if the + model config is unavailable. + """ + + if not self.model_config.is_encoder_decoder: + logger.warning_once( + "Using None for decoder start token id because " + "this is not an encoder/decoder model.") + return None + + if self.model_config is None or self.model_config.hf_config is None: + logger.warning_once( + "Using None for decoder start token id because " + "model config is not available.") + return None + + dec_start_token_id = getattr(self.model_config.hf_config, + "decoder_start_token_id", None) + if dec_start_token_id is None: + logger.warning_once( + "Falling back on for decoder start token " + "id because decoder start token id is not " + "available.") + dec_start_token_id = self.get_bos_token_id() + + return dec_start_token_id + + def _get_default_enc_dec_decoder_prompt(self) -> list[int]: + """ + Specifically for encoder/decoder models: + generate a default decoder prompt for when + the user specifies only the encoder prompt. + + Encoder/decoder models utilize the decoder + prompt in different ways; as new models are + added, it is intended that this function + will be extended to produce differing + default decoder prompts, depending on the + model variety. + + Absent a special case, the default behavior + of this method is to mirror the behavior of + the HuggingFace (HF) GenerationMixin for a None + decoder prompt, which is to employ a logit processor + setting to force the first decoded token to be . + Here, this behavior is approximated by having the + "default" decoder prompt be . + + However, it is possible that in the future + other models may have different or more + complex logic for the default decoder prompt. + This motivates having a special helper method + for default decoder prompts. + + Returns: + + * prompt_token_ids + """ + + bos_token_id = self.get_bos_token_id() + assert bos_token_id is not None + return [bos_token_id] + + def _prepare_decoder_input_ids_for_generation( + self, + decoder_input_ids: Optional[list[int]], + ) -> list[int]: + """ + Prepares `decoder_input_ids` for generation with encoder-decoder models. + + Based on: + https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py + specifically, + `GenerationMixin._prepare_decoder_input_ids_for_generation()`. + + Arguments: + + * decoder_input_ids: input token ids to preprocess + + Returns: + + * Processed token list + """ + + decoder_start_token_id = self.get_decoder_start_token_id() + assert decoder_start_token_id is not None + + if decoder_input_ids is None: + # no decoder prompt input -> + # use decoder_start_token_id as decoder_input_ids + decoder_input_ids = self._get_default_enc_dec_decoder_prompt() + + if (len(decoder_input_ids) == 0 + or decoder_input_ids[0] != decoder_start_token_id): + decoder_input_ids = [decoder_start_token_id] + decoder_input_ids + + return decoder_input_ids + + def _apply_prompt_adapter( + self, + prompt_token_ids: list[int], + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> list[int]: + if prompt_adapter_request: + prompt_token_ids = ( + [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + + prompt_token_ids) + + return prompt_token_ids + + def _get_tokenization_kw( + self, + overrides: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + kwargs = dict[str, Any]() + + if self.model_config.hf_config.model_type == "whisper": + # For Whisper, special tokens should be provided by the user based + # on the task and language of their request. Also needed to avoid + # appending an EOS token to the prompt which disrupts generation. + kwargs["add_special_tokens"] = False + + if overrides: + kwargs.update(overrides) + + return kwargs + + def _tokenize_prompt( + self, + prompt: str, + lora_request: Optional[LoRARequest], + tokenization_kwargs: Optional[dict[str, Any]] = None, + ) -> list[int]: + """ + Apply the model's tokenizer to a text prompt, returning the + corresponding token IDs. + """ + tokenizer = self.get_tokenizer_group() + tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs) + + encoder_config = self.model_config.encoder_config + + if encoder_config and encoder_config.get("do_lower_case", False): + prompt = prompt.lower() + + if self.model_config.tokenizer_mode == "cpm": + return [tokenizer.bos_id] + tokenizer.encode(prompt) + else: + return tokenizer.encode(prompt=prompt, + lora_request=lora_request, + **tokenization_kwargs) + + async def _tokenize_prompt_async( + self, + prompt: str, + lora_request: Optional[LoRARequest], + tokenization_kwargs: Optional[dict[str, Any]] = None, + ) -> list[int]: + """ + Async version of + [`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt]. + """ + tokenizer = self.get_tokenizer_group() + tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs) + + return await tokenizer.encode_async(prompt=prompt, + lora_request=lora_request, + **tokenization_kwargs) + + def _get_mm_tokenizer( + self, + lora_request: Optional[LoRARequest], + ) -> AnyTokenizer: + # PrithviGeoSpatialMAE needs to be initialized without a tokenizer + # while using also multi-modal input + if not self.tokenizer: + return cast(AnyTokenizer, object()) # Dummy + + tokenizer_group = self.get_tokenizer_group() + return tokenizer_group.get_lora_tokenizer(lora_request) + + async def _get_mm_tokenizer_async( + self, + lora_request: Optional[LoRARequest], + ) -> AnyTokenizer: + # PrithviGeoSpatialMAE needs to be initialized without a tokenizer + # while using also multi-modal input + if not self.tokenizer: + return cast(AnyTokenizer, object()) # Dummy + + tokenizer_group = self.get_tokenizer_group() + return await tokenizer_group.get_lora_tokenizer_async(lora_request) + + def _process_multimodal( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + mm_processor_kwargs: Optional[Mapping[str, object]], + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> MultiModalInputs: + """ + Apply the model's multi-modal processor to a multi-modal prompt, + returning the corresponding token IDs and metadata. + """ + tokenizer = self._get_mm_tokenizer(lora_request) + + mm_processor = self.mm_registry.create_processor(self.model_config, + tokenizer=tokenizer) + + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + + return mm_processor.apply(prompt, + mm_data, + hf_processor_mm_kwargs=mm_processor_kwargs, + tokenization_kwargs=tokenization_kwargs, + return_mm_hashes=return_mm_hashes) + + async def _process_multimodal_async( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + mm_processor_kwargs: Optional[Mapping[str, object]], + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> MultiModalInputs: + """ + Async version of + [`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal]. + """ + tokenizer = await self._get_mm_tokenizer_async(lora_request) + + mm_processor = self.mm_registry.create_processor(self.model_config, + tokenizer=tokenizer) + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + + return mm_processor.apply(prompt, + mm_data, + hf_processor_mm_kwargs=mm_processor_kwargs, + tokenization_kwargs=tokenization_kwargs, + return_mm_hashes=return_mm_hashes) + + def _process_embeds( + self, + parsed_content: EmbedsPrompt, + ) -> EmbedsInputs: + if not self.model_config.enable_prompt_embeds: + raise ValueError("You must set `--enable-prompt-embeds` to input " + "`prompt_embeds`.") + + prompt_embeds = parsed_content["prompt_embeds"] + + # prompt_embeds must be (seq_len, hidden_size), but if the user + # passes in a batch of size 1, i.e. (1, seq_len, hidden_size), + # we can unambiguously process the intent by squeezing the batch + # dimension. + if prompt_embeds.ndim == 3: + prompt_embeds = prompt_embeds.squeeze(dim=0) + + if prompt_embeds.ndim != 2: + raise ValueError( + "prompt_embeds must be of shape (seq_len, hidden_size).") + + return embeds_inputs(prompt_embeds=prompt_embeds, + cache_salt=parsed_content.get("cache_salt")) + + async def _process_embeds_async( + self, + parsed_content: EmbedsPrompt, + ) -> EmbedsInputs: + return self._process_embeds(parsed_content) + + def _process_tokens( + self, + parsed_content: TokensPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> Union[TokenInputs, MultiModalInputs]: + prompt_token_ids = parsed_content["prompt_token_ids"] + token_type_ids = parsed_content.get("token_type_ids") + + inputs: Union[TokenInputs, MultiModalInputs] + if multi_modal_data := parsed_content.get("multi_modal_data"): + inputs = self._process_multimodal( + prompt_token_ids, + multi_modal_data, + parsed_content.get("mm_processor_kwargs"), + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + else: + inputs = token_inputs( + prompt_token_ids=prompt_token_ids, + token_type_ids=token_type_ids, + ) + + if cache_salt := parsed_content.get("cache_salt"): + inputs["cache_salt"] = cache_salt + + return inputs + + async def _process_tokens_async( + self, + parsed_content: TokensPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> Union[TokenInputs, MultiModalInputs]: + prompt_token_ids = parsed_content["prompt_token_ids"] + token_type_ids = parsed_content.get("token_type_ids") + + inputs: Union[TokenInputs, MultiModalInputs] + if multi_modal_data := parsed_content.get("multi_modal_data"): + inputs = await self._process_multimodal_async( + prompt_token_ids, + multi_modal_data, + parsed_content.get("mm_processor_kwargs"), + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + else: + inputs = token_inputs( + prompt_token_ids=prompt_token_ids, + token_type_ids=token_type_ids, + ) + + if cache_salt := parsed_content.get("cache_salt"): + inputs["cache_salt"] = cache_salt + + return inputs + + def _process_text( + self, + parsed_content: TextPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> Union[TokenInputs, MultiModalInputs]: + prompt_text = parsed_content["prompt"] + + inputs: Union[TokenInputs, MultiModalInputs] + if multi_modal_data := parsed_content.get("multi_modal_data"): + inputs = self._process_multimodal( + prompt_text, + multi_modal_data, + parsed_content.get("mm_processor_kwargs"), + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + else: + prompt_token_ids = self._tokenize_prompt( + prompt_text, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + ) + inputs = token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + ) + + if cache_salt := parsed_content.get("cache_salt"): + inputs["cache_salt"] = cache_salt + + return inputs + + async def _process_text_async( + self, + parsed_content: TextPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> Union[TokenInputs, MultiModalInputs]: + prompt_text = parsed_content["prompt"] + + inputs: Union[TokenInputs, MultiModalInputs] + if multi_modal_data := parsed_content.get("multi_modal_data"): + inputs = await self._process_multimodal_async( + prompt_text, + multi_modal_data, + parsed_content.get("mm_processor_kwargs"), + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + else: + prompt_token_ids = await self._tokenize_prompt_async( + prompt_text, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + ) + inputs = token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + ) + + if cache_salt := parsed_content.get("cache_salt"): + inputs["cache_salt"] = cache_salt + + return inputs + + def _prompt_to_llm_inputs( + self, + prompt: SingletonPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> SingletonInputs: + """ + Extract the singleton inputs from a prompt. + + Arguments: + + * prompt: single encoder or decoder input prompt + * lora_request: this is only valid for decoder prompts + * return_mm_hashes: whether to return multimodal hashes + + Returns: + + * [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance + """ + parsed = parse_singleton_prompt(prompt) + + if parsed["type"] == "embeds": + return self._process_embeds(parsed["content"]) + if parsed["type"] == "tokens": + return self._process_tokens( + parsed["content"], + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + if parsed["type"] == "text": + return self._process_text( + parsed["content"], + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + if parsed["type"] == "str": + return self._process_text( + TextPrompt(prompt=parsed["content"]), + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + + assert_never(parsed) + + async def _prompt_to_llm_inputs_async( + self, + prompt: SingletonPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + return_mm_hashes: bool = False, + ) -> SingletonInputs: + """ + Async version of + [`_prompt_to_llm_inputs`][vllm.inputs.preprocess.InputPreprocessor._prompt_to_llm_inputs]. + """ + parsed = parse_singleton_prompt(prompt) + + if parsed["type"] == "embeds": + return await self._process_embeds_async(parsed["content"]) + if parsed["type"] == "tokens": + return await self._process_tokens_async( + parsed["content"], + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + if parsed["type"] == "text": + return await self._process_text_async( + parsed["content"], + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + if parsed["type"] == "str": + return await self._process_text_async( + TextPrompt(prompt=parsed["content"]), + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + + assert_never(parsed) + + def _build_enc_dec_llm_inputs( + self, + encoder_inputs: SingletonInputs, + decoder_inputs: Optional[SingletonInputs], + ) -> EncoderDecoderInputs: + if (encoder_inputs["type"] == "embeds" + or decoder_inputs and decoder_inputs["type"] == "embeds"): + raise ValueError("Embedding inputs are not supported for encoder-" + "decoder models") + + # Needed for mypy + encoder_inputs = cast(Union[TokenInputs, MultiModalInputs], + encoder_inputs) + decoder_inputs = cast(Optional[Union[TokenInputs, MultiModalInputs]], + decoder_inputs) + + if decoder_inputs is None: + if self.model_config.hf_config.model_type == "whisper": + # For Whisper models, the text prompt should go to the decoder. + # If no explicit encoder/decoder inputs, then copy the prompt + # from the encoder to the decoder. The encoder tokens are later + # overridden by the audio features. + dec_token_ids = encoder_inputs["prompt_token_ids"].copy() + else: + dec_token_ids = self._prepare_decoder_input_ids_for_generation( + None) + decoder_inputs = token_inputs(dec_token_ids) + else: + if "multi_modal_data" in decoder_inputs: + raise ValueError("Multi-modal decoder inputs of encoder-" + "decoder models are not supported yet") + + dec_token_ids = self._prepare_decoder_input_ids_for_generation( + decoder_inputs["prompt_token_ids"]) + decoder_inputs["prompt_token_ids"] = dec_token_ids + + return EncoderDecoderInputs( + encoder=encoder_inputs, + decoder=decoder_inputs, + ) + + def _split_enc_dec_mm_inputs( + self, + inputs: Union[SingletonInputs, MultiModalEncDecInputs], + decoder_inputs_to_override: Optional[SingletonInputs] = None, + ) -> tuple[SingletonInputs, SingletonInputs]: + """ + For encoder/decoder models only: + Separate Encoder/Decoder inputs from a MultiModalEncDecInputs + """ + if (inputs["type"] == "embeds" or decoder_inputs_to_override + and decoder_inputs_to_override["type"] == "embeds"): + raise ValueError("Embedding inputs are not supported for encoder-" + "decoder models") + + # Needed for mypy + inputs = cast( + Union[TokenInputs, MultiModalInputs, MultiModalEncDecInputs], + inputs, + ) + decoder_inputs_to_override = cast( + Optional[Union[TokenInputs, MultiModalInputs]], + decoder_inputs_to_override, + ) + + encoder_inputs: SingletonInputs + decoder_inputs: SingletonInputs + + if inputs["type"] == "multimodal": # Multimodal data inputs + if not ("encoder_prompt" in inputs + and "encoder_prompt_token_ids" in inputs): + raise RuntimeError("You should register an encoder-decoder " + "multi-modal processor for encoder-decoder " + "models.") + inputs = cast(MultiModalEncDecInputs, inputs) + + encoder_inputs = token_inputs( + prompt=inputs["encoder_prompt"], + prompt_token_ids=inputs["encoder_prompt_token_ids"], + ) + + decoder_prompt_inputs = decoder_inputs_to_override or inputs + decoder_inputs = MultiModalInputs( + type="multimodal", + prompt=decoder_prompt_inputs.get("prompt", ""), + prompt_token_ids=decoder_prompt_inputs["prompt_token_ids"], + mm_kwargs=inputs["mm_kwargs"], + mm_hashes=inputs["mm_hashes"], + mm_placeholders=inputs["mm_placeholders"], + ) + if cache_salt := inputs.get("cache_salt"): + decoder_inputs["cache_salt"] = cache_salt + + elif inputs["type"] == "token": # Text-only inputs + encoder_inputs = token_inputs(prompt="", prompt_token_ids=[]) + decoder_inputs = decoder_inputs_to_override or inputs + else: + assert_never(inputs) # type: ignore[arg-type] + + return encoder_inputs, decoder_inputs + + def _process_encoder_decoder_prompt( + self, + prompt: PromptType, + tokenization_kwargs: Optional[dict[str, Any]] = None, + ) -> EncoderDecoderInputs: + """ + For encoder/decoder models only: + Process an input prompt into an + [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs] + instance. + + There are two types of input prompts: + singleton prompts which carry only the + encoder prompt, and explicit encoder/decoder + prompts which carry both the encoder and the + decoder prompts as member variables. + + This function handles the following scenarios: + * Singleton encoder prompt: extract encoder prompt + token ids & infer default decoder prompt token ids + * Explicit encoder/decoder prompt: extract encoder + and decoder prompt token ids + + Note that for Explicit encoder/decoder prompts, + each sub-prompt (encoder or decoder prompt) can + have any possible singleton type; thus this + method relies on helper functions to obtain + token ids for the sub-prompts. + + Arguments: + + * prompt: an input prompt + + Returns: + + * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs] + instance + """ + encoder_inputs: SingletonInputs + decoder_inputs: Optional[SingletonInputs] + + if is_explicit_encoder_decoder_prompt(prompt): + encoder_inputs = self._prompt_to_llm_inputs( + prompt["encoder_prompt"], + tokenization_kwargs=tokenization_kwargs, + ) + if (decoder_input := prompt["decoder_prompt"]) is None: + decoder_inputs = None + else: + decoder_inputs = self._prompt_to_llm_inputs(decoder_input) + # For multimodal model, override decoder prompt from processor + # with explicit decoder prompt. + if self.model_config.is_multimodal_model: + encoder_inputs, decoder_inputs = ( + self._split_enc_dec_mm_inputs(encoder_inputs, + decoder_inputs)) + else: + inputs = self._prompt_to_llm_inputs( + prompt, + tokenization_kwargs=tokenization_kwargs, + ) + if self.model_config.is_multimodal_model: + # Encoder-Decoder Multimodal model + encoder_inputs, decoder_inputs = ( + self._split_enc_dec_mm_inputs(inputs)) + else: + encoder_inputs = inputs + decoder_inputs = None + + return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) + + async def _process_encoder_decoder_prompt_async( + self, + prompt: PromptType, + tokenization_kwargs: Optional[dict[str, Any]] = None, + ) -> EncoderDecoderInputs: + """ + Async version of + [`_process_encoder_decoder_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_encoder_decoder_prompt]. + """ + encoder_inputs: SingletonInputs + decoder_inputs: Optional[SingletonInputs] + + if is_explicit_encoder_decoder_prompt(prompt): + encoder_task = self._prompt_to_llm_inputs_async( + prompt["encoder_prompt"], + tokenization_kwargs=tokenization_kwargs, + ) + + if (decoder_input := prompt["decoder_prompt"]) is None: + encoder_inputs = await encoder_task + decoder_inputs = None + else: + decoder_task = self._prompt_to_llm_inputs_async( + decoder_input, + tokenization_kwargs=tokenization_kwargs, + ) + + encoder_inputs, decoder_inputs = await asyncio.gather( + encoder_task, decoder_task) + + # For multimodal model, override decoder prompt from processor + # with explicit decoder prompt. + if self.model_config.is_multimodal_model: + encoder_inputs, decoder_inputs = ( + self._split_enc_dec_mm_inputs(encoder_inputs, + decoder_inputs)) + else: + inputs = await self._prompt_to_llm_inputs_async( + prompt, + tokenization_kwargs=tokenization_kwargs, + ) + if self.model_config.is_multimodal_model: + # Encoder-Decoder Multimodal model + encoder_inputs, decoder_inputs = ( + self._split_enc_dec_mm_inputs(inputs)) + else: + encoder_inputs = inputs + decoder_inputs = None + + return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) + + def _build_decoder_only_llm_inputs( + self, + prompt_inputs: DecoderOnlyInputs, + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> DecoderOnlyInputs: + if "prompt_token_ids" in prompt_inputs: + prompt_inputs = cast(Union[TokenInputs, MultiModalInputs], + prompt_inputs) # Needed for mypy + prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter( + prompt_inputs["prompt_token_ids"], + prompt_adapter_request=prompt_adapter_request, + ) + + return prompt_inputs + + def _process_decoder_only_prompt( + self, + prompt: SingletonPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + return_mm_hashes: bool = False, + ) -> DecoderOnlyInputs: + """ + For decoder-only models: + Process an input prompt into a + [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance. + + Arguments: + + * prompt: input prompt + * lora_request + * prompt_adapter_request + * return_mm_hashes + + Returns: + + * [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance + """ + + prompt_comps = self._prompt_to_llm_inputs( + prompt, + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + + return self._build_decoder_only_llm_inputs( + prompt_comps, + prompt_adapter_request=prompt_adapter_request, + ) + + async def _process_decoder_only_prompt_async( + self, + prompt: SingletonPrompt, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + return_mm_hashes: bool = False, + ) -> DecoderOnlyInputs: + """ + Async version of + [`_process_decoder_only_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_decoder_only_prompt]. + """ + prompt_comps = await self._prompt_to_llm_inputs_async( + prompt, + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + return_mm_hashes=return_mm_hashes, + ) + + return self._build_decoder_only_llm_inputs( + prompt_comps, + prompt_adapter_request=prompt_adapter_request, + ) + + def preprocess( + self, + prompt: PromptType, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + return_mm_hashes: bool = False, + ) -> ProcessorInputs: + """Preprocess the input prompt.""" + if self.model_config.is_encoder_decoder: + assert not return_mm_hashes, ( + "Multimodal hashes for encoder-decoder models should not be ", + "returned until they are supported on vLLM V1.") + # Encoder-decoder model requires special mapping of + # input prompts to encoder & decoder + return self._process_encoder_decoder_prompt( + prompt, tokenization_kwargs) + + if is_explicit_encoder_decoder_prompt(prompt): + raise ValueError("Cannot pass encoder-decoder prompt " + "to decoder-only models") + + # Decoder-only operation + return self._process_decoder_only_prompt( + prompt, + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + return_mm_hashes=return_mm_hashes, + ) + + async def preprocess_async( + self, + prompt: PromptType, + tokenization_kwargs: Optional[dict[str, Any]] = None, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + return_mm_hashes: bool = False, + ) -> ProcessorInputs: + """ + Async version of + [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess]. + """ + if self.model_config.is_encoder_decoder: + assert not return_mm_hashes, ( + "Multimodal hashes for encoder-decoder models should not be ", + "returned until they are supported on vLLM V1.") + # Encoder-decoder model requires special mapping of + # input prompts to encoder & decoder + return await self._process_encoder_decoder_prompt_async(prompt) + + if is_explicit_encoder_decoder_prompt(prompt): + raise ValueError("Cannot pass encoder-decoder prompt " + "to decoder-only models") + + # Decoder-only operation + return await self._process_decoder_only_prompt_async( + prompt, + tokenization_kwargs=tokenization_kwargs, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + return_mm_hashes=return_mm_hashes, + ) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py new file mode 100644 index 0000000..fc6e190 --- /dev/null +++ b/vllm/inputs/registry.py @@ -0,0 +1,245 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Mapping +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union + +import torch +from packaging.version import Version +from transformers import BatchFeature, PretrainedConfig, ProcessorMixin +from transformers import __version__ as TRANSFORMERS_VERSION +from typing_extensions import TypeVar + +from vllm.jsontree import JSONTree, json_map_leaves +from vllm.logger import init_logger +from vllm.transformers_utils.processor import cached_processor_from_config +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import resolve_mm_processor_kwargs + +if TYPE_CHECKING: + from vllm.config import ModelConfig + from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict, + MultiModalRegistry) + from vllm.sequence import SequenceData + +_T = TypeVar("_T") +_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig) +_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin) + +logger = init_logger(__name__) + + +@dataclass(frozen=True) +class InputContext: + """ + Contains information about the model which may be used to + modify the inputs. + """ + + model_config: "ModelConfig" + """The configuration of the model.""" + + def get_hf_config( + self, + typ: Union[type[_C], tuple[type[_C], ...]] = PretrainedConfig, + /, + ) -> _C: + """ + Get the HuggingFace configuration + (`transformers.PretrainedConfig`) of the model, + additionally checking its type. + + Raises: + TypeError: If the configuration is not of the specified type. + """ + hf_config = self.model_config.hf_config + if not isinstance(hf_config, typ): + raise TypeError("Invalid type of HuggingFace config. " + f"Expected type: {typ}, but " + f"found type: {type(hf_config)}") + + return hf_config + + def get_hf_image_processor_config(self) -> dict[str, Any]: + """ + Get the HuggingFace image processor configuration of the model. + """ + return self.model_config.hf_image_processor_config + + def get_mm_config(self): + """ + Get the multimodal config of the model. + + Raises: + RuntimeError: If the model is not a multimodal model. + """ + mm_config = self.model_config.multimodal_config + if mm_config is None: + raise RuntimeError("Not a multimodal model") + + return mm_config + + def get_hf_processor( + self, + typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin, + /, + **kwargs: object, + ) -> _P: + """ + Get the HuggingFace processor + (`transformers.ProcessorMixin`) of the model, + additionally checking its type. + + Raises: + TypeError: If the processor is not of the specified type. + """ + return cached_processor_from_config( + self.model_config, + processor_cls=typ, + **kwargs, + ) + + def init_processor( + self, + typ: type[_T], + /, + **kwargs: object, + ) -> _T: + """ + Initialize a HuggingFace-like processor class, merging the + keyword arguments with those in the model's configuration. + """ + mm_config = self.model_config.get_multimodal_config() + base_kwargs = mm_config.mm_processor_kwargs + if base_kwargs is None: + base_kwargs = {} + + merged_kwargs = {**base_kwargs, **kwargs} + + return typ(**merged_kwargs) + + +@dataclass(frozen=True) +class InputProcessingContext(InputContext): + tokenizer: AnyTokenizer + """The tokenizer used to tokenize the inputs.""" + + def get_hf_processor( + self, + typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin, + /, + **kwargs: object, + ) -> _P: + # Transformers 4.53.0 has issue with passing tokenizer to + # initialize processor. We disable it for this version. + # See: https://github.com/vllm-project/vllm/issues/20224 + if Version(TRANSFORMERS_VERSION) != Version("4.53.0"): + kwargs["tokenizer"] = self.tokenizer + return super().get_hf_processor( + typ, + **kwargs, + ) + + def call_hf_processor( + self, + hf_processor: ProcessorMixin, + data: Mapping[str, object], + kwargs: Mapping[str, object] = {}, + ) -> Union[BatchFeature, JSONTree]: + """ + Call `hf_processor` on the prompt `data` + (text, image, audio...) with configurable options `kwargs`. + """ + assert callable(hf_processor) + + mm_config = self.model_config.get_multimodal_config() + base_kwargs = mm_config.mm_processor_kwargs + if base_kwargs is None: + base_kwargs = {} + + merged_kwargs = resolve_mm_processor_kwargs( + base_kwargs, + kwargs, + hf_processor, + requires_kw_only=False, + allow_var_kwargs=True, + ) + + def maybe_cast_dtype(x): + # This mimics the behavior of transformers.BatchFeature + if isinstance(x, torch.Tensor) and x.is_floating_point(): + return x.to(dtype=self.model_config.dtype) + return x + + try: + output = hf_processor(**data, **merged_kwargs, return_tensors="pt") + # this emulates output.to(dtype=self.model_config.dtype) + if isinstance(output, BatchFeature): + cast_output = json_map_leaves(maybe_cast_dtype, output.data) + return BatchFeature(cast_output) + + cast_output = json_map_leaves(maybe_cast_dtype, output) + + logger.warning_once( + f"{type(hf_processor).__name__} did not return `BatchFeature`. " + "Make sure to match the behaviour of `ProcessorMixin` when " + "implementing custom processors.") + return cast_output + + except Exception as exc: + msg = (f"Failed to apply {type(hf_processor).__name__} " + f"on data={data} with kwargs={merged_kwargs}") + + raise ValueError(msg) from exc + + +class DummyData(NamedTuple): + """ + Dummy data used for profiling. + + Note: This is only used in V0. + """ + + seq_data: "SequenceData" + multi_modal_data: Optional["MultiModalDataDict"] = None + multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None + + +class InputRegistry: + """ + Note: This is only used in V0. + """ + + def dummy_data_for_profiling( + self, + model_config: "ModelConfig", + seq_len: int, + mm_registry: "MultiModalRegistry", + is_encoder_data: bool = False, + ) -> DummyData: + """ + Create dummy data for profiling the memory usage of a model. + + The model is identified by ``model_config``. + """ + # Avoid circular import + from vllm.sequence import SequenceData + + if not model_config.is_multimodal_model: + seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) + return DummyData(seq_data=seq_data) + + # Encoder dummy data does not contain multi-modal data + if is_encoder_data: + enc_data = mm_registry.get_encoder_dummy_data( + model_config, seq_len) + seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids) + return DummyData(seq_data=seq_data) + + dec_data = mm_registry.get_decoder_dummy_data(model_config, seq_len) + + return DummyData( + seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids), + multi_modal_data=dec_data.multi_modal_data, + multi_modal_placeholders=dec_data.multi_modal_placeholders, + ) diff --git a/vllm/jsontree.py b/vllm/jsontree.py new file mode 100644 index 0000000..4cbe0f7 --- /dev/null +++ b/vllm/jsontree.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Helper functions to work with nested JSON structures.""" +from collections.abc import Iterable +from functools import reduce +from typing import Callable, TypeVar, Union, overload + +_T = TypeVar("_T") +_U = TypeVar("_U") + +JSONTree = Union[dict[str, "JSONTree[_T]"], list["JSONTree[_T]"], + tuple["JSONTree[_T]", ...], _T] +"""A nested JSON structure where the leaves need not be JSON-serializable.""" + + +def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]: + """Iterate through each leaf in a nested JSON structure.""" + if isinstance(value, dict): + for v in value.values(): + yield from json_iter_leaves(v) + elif isinstance(value, (list, tuple)): + for v in value: + yield from json_iter_leaves(v) + else: + yield value + + +def json_map_leaves( + func: Callable[[_T], _U], + value: JSONTree[_T], +) -> JSONTree[_U]: + """Apply a function to each leaf in a nested JSON structure.""" + if isinstance(value, dict): + return {k: json_map_leaves(func, v) for k, v in value.items()} + elif isinstance(value, list): + return [json_map_leaves(func, v) for v in value] + elif isinstance(value, tuple): + return tuple(json_map_leaves(func, v) for v in value) + else: + return func(value) + + +@overload +def json_reduce_leaves( + func: Callable[[_T, _T], _T], + value: JSONTree[_T], + /, +) -> _T: + ... + + +@overload +def json_reduce_leaves( + func: Callable[[_U, _T], _U], + value: JSONTree[_T], + initial: _U, + /, +) -> _U: + ... + + +def json_reduce_leaves( + func: Callable[..., Union[_T, _U]], + value: JSONTree[_T], + initial: _U = ..., # type: ignore[assignment] + /, +) -> Union[_T, _U]: + """ + Apply a function of two arguments cumulatively to each leaf in a + nested JSON structure, from left to right, so as to reduce the + sequence to a single value. + """ + if initial is ...: + return reduce(func, json_iter_leaves(value)) # type: ignore[arg-type] + + return reduce( + func, # type: ignore[arg-type] + json_iter_leaves(value), + initial, + ) diff --git a/vllm/logger.py b/vllm/logger.py new file mode 100644 index 0000000..0ddb83c --- /dev/null +++ b/vllm/logger.py @@ -0,0 +1,212 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Logging configuration for vLLM.""" +import datetime +import json +import logging +import os +import sys +from collections.abc import Hashable +from functools import lru_cache, partial +from logging import Logger +from logging.config import dictConfig +from os import path +from types import MethodType +from typing import Any, Optional, cast + +import vllm.envs as envs + +VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING +VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH +VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL +VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX + +_FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s " + "[%(filename)s:%(lineno)d] %(message)s") +_DATE_FORMAT = "%m-%d %H:%M:%S" + +DEFAULT_LOGGING_CONFIG = { + "formatters": { + "vllm": { + "class": "vllm.logging_utils.NewLineFormatter", + "datefmt": _DATE_FORMAT, + "format": _FORMAT, + }, + }, + "handlers": { + "vllm": { + "class": "logging.StreamHandler", + "formatter": "vllm", + "level": VLLM_LOGGING_LEVEL, + "stream": "ext://sys.stdout", + }, + }, + "loggers": { + "vllm": { + "handlers": ["vllm"], + "level": "DEBUG", + "propagate": False, + }, + }, + "version": 1, + "disable_existing_loggers": False +} + + +@lru_cache +def _print_info_once(logger: Logger, msg: str, *args: Hashable) -> None: + # Set the stacklevel to 2 to print the original caller's line info + logger.info(msg, *args, stacklevel=2) + + +@lru_cache +def _print_warning_once(logger: Logger, msg: str, *args: Hashable) -> None: + # Set the stacklevel to 2 to print the original caller's line info + logger.warning(msg, *args, stacklevel=2) + + +class _VllmLogger(Logger): + """ + Note: + This class is just to provide type information. + We actually patch the methods directly on the [`logging.Logger`][] + instance to avoid conflicting with other libraries such as + `intel_extension_for_pytorch.utils._logger`. + """ + + def info_once(self, msg: str, *args: Hashable) -> None: + """ + As [`info`][logging.Logger.info], but subsequent calls with + the same message are silently dropped. + """ + _print_info_once(self, msg, *args) + + def warning_once(self, msg: str, *args: Hashable) -> None: + """ + As [`warning`][logging.Logger.warning], but subsequent calls with + the same message are silently dropped. + """ + _print_warning_once(self, msg, *args) + + +def _configure_vllm_root_logger() -> None: + logging_config = dict[str, Any]() + + if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH: + raise RuntimeError( + "VLLM_CONFIGURE_LOGGING evaluated to false, but " + "VLLM_LOGGING_CONFIG_PATH was given. VLLM_LOGGING_CONFIG_PATH " + "implies VLLM_CONFIGURE_LOGGING. Please enable " + "VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH.") + + if VLLM_CONFIGURE_LOGGING: + logging_config = DEFAULT_LOGGING_CONFIG + + if VLLM_LOGGING_CONFIG_PATH: + if not path.exists(VLLM_LOGGING_CONFIG_PATH): + raise RuntimeError( + "Could not load logging config. File does not exist: %s", + VLLM_LOGGING_CONFIG_PATH) + with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8") as file: + custom_config = json.loads(file.read()) + + if not isinstance(custom_config, dict): + raise ValueError("Invalid logging config. Expected dict, got %s.", + type(custom_config).__name__) + logging_config = custom_config + + for formatter in logging_config.get("formatters", {}).values(): + # This provides backwards compatibility after #10134. + if formatter.get("class") == "vllm.logging.NewLineFormatter": + formatter["class"] = "vllm.logging_utils.NewLineFormatter" + + if logging_config: + dictConfig(logging_config) + + +def init_logger(name: str) -> _VllmLogger: + """The main purpose of this function is to ensure that loggers are + retrieved in such a way that we can be sure the root vllm logger has + already been configured.""" + + logger = logging.getLogger(name) + + methods_to_patch = { + "info_once": _print_info_once, + "warning_once": _print_warning_once, + } + + for method_name, method in methods_to_patch.items(): + setattr(logger, method_name, MethodType(method, logger)) + + return cast(_VllmLogger, logger) + + +# The root logger is initialized when the module is imported. +# This is thread-safe as the module is only imported once, +# guaranteed by the Python GIL. +_configure_vllm_root_logger() + +logger = init_logger(__name__) + + +def _trace_calls(log_path, root_dir, frame, event, arg=None): + if event in ['call', 'return']: + # Extract the filename, line number, function name, and the code object + filename = frame.f_code.co_filename + lineno = frame.f_lineno + func_name = frame.f_code.co_name + if not filename.startswith(root_dir): + # only log the functions in the vllm root_dir + return + # Log every function call or return + try: + last_frame = frame.f_back + if last_frame is not None: + last_filename = last_frame.f_code.co_filename + last_lineno = last_frame.f_lineno + last_func_name = last_frame.f_code.co_name + else: + # initial frame + last_filename = "" + last_lineno = 0 + last_func_name = "" + with open(log_path, 'a') as f: + ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") + if event == 'call': + f.write(f"{ts} Call to" + f" {func_name} in {filename}:{lineno}" + f" from {last_func_name} in {last_filename}:" + f"{last_lineno}\n") + else: + f.write(f"{ts} Return from" + f" {func_name} in {filename}:{lineno}" + f" to {last_func_name} in {last_filename}:" + f"{last_lineno}\n") + except NameError: + # modules are deleted during shutdown + pass + return partial(_trace_calls, log_path, root_dir) + + +def enable_trace_function_call(log_file_path: str, + root_dir: Optional[str] = None): + """ + Enable tracing of every function call in code under `root_dir`. + This is useful for debugging hangs or crashes. + `log_file_path` is the path to the log file. + `root_dir` is the root directory of the code to trace. If None, it is the + vllm root directory. + + Note that this call is thread-level, any threads calling this function + will have the trace enabled. Other threads will not be affected. + """ + logger.warning( + "VLLM_TRACE_FUNCTION is enabled. It will record every" + " function executed by Python. This will slow down the code. It " + "is suggested to be used for debugging hang or crashes only.") + logger.info("Trace frame log is saved to %s", log_file_path) + if root_dir is None: + # by default, this is the vllm root directory + root_dir = os.path.dirname(os.path.dirname(__file__)) + sys.settrace(partial(_trace_calls, log_file_path, root_dir)) diff --git a/vllm/logging_utils/__init__.py b/vllm/logging_utils/__init__.py new file mode 100644 index 0000000..cf690a8 --- /dev/null +++ b/vllm/logging_utils/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.logging_utils.formatter import NewLineFormatter + +__all__ = [ + "NewLineFormatter", +] diff --git a/vllm/logging_utils/dump_input.py b/vllm/logging_utils/dump_input.py new file mode 100644 index 0000000..ad89638 --- /dev/null +++ b/vllm/logging_utils/dump_input.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import contextlib +import enum +import json +from typing import Optional + +import torch + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.metrics.stats import SchedulerStats +from vllm.version import __version__ as VLLM_VERSION + +logger = init_logger(__name__) + + +def prepare_object_to_dump(obj) -> str: + if isinstance(obj, str): + return f"'{obj}'" # Double quotes + elif isinstance(obj, dict): + dict_str = ', '.join({f'{str(k)}: {prepare_object_to_dump(v)}' \ + for k, v in obj.items()}) + return f'{{{dict_str}}}' + elif isinstance(obj, list): + return f"[{', '.join([prepare_object_to_dump(v) for v in obj])}]" + elif isinstance(obj, set): + return f"[{', '.join([prepare_object_to_dump(v) for v in list(obj)])}]" + # return [prepare_object_to_dump(v) for v in list(obj)] + elif isinstance(obj, tuple): + return f"[{', '.join([prepare_object_to_dump(v) for v in obj])}]" + elif isinstance(obj, enum.Enum): + return repr(obj) + elif isinstance(obj, torch.Tensor): + # We only print the 'draft' of the tensor to not expose sensitive data + # and to get some metadata in case of CUDA runtime crashed + return (f"Tensor(shape={obj.shape}, " + f"device={obj.device}," + f"dtype={obj.dtype})") + elif hasattr(obj, 'anon_repr'): + return obj.anon_repr() + elif hasattr(obj, '__dict__'): + items = obj.__dict__.items() + dict_str = ', '.join([f'{str(k)}={prepare_object_to_dump(v)}' \ + for k, v in items]) + return f"{type(obj).__name__}({dict_str})" + else: + # Hacky way to make sure we can serialize the object in JSON format + try: + return json.dumps(obj) + except (TypeError, OverflowError): + return repr(obj) + + +def dump_engine_exception(config: VllmConfig, + scheduler_output: SchedulerOutput, + scheduler_stats: Optional[SchedulerStats]): + # NOTE: ensure we can log extra info without risking raises + # unexpected errors during logging + with contextlib.suppress(Exception): + _dump_engine_exception(config, scheduler_output, scheduler_stats) + + +def _dump_engine_exception(config: VllmConfig, + scheduler_output: SchedulerOutput, + scheduler_stats: Optional[SchedulerStats]): + logger.error( + "Dumping input data for V1 LLM engine (v%s) with config: %s, ", + VLLM_VERSION, + config, + ) + try: + dump_obj = prepare_object_to_dump(scheduler_output) + logger.error("Dumping scheduler output for model execution: %s", + dump_obj) + if scheduler_stats: + logger.error("Dumping scheduler stats: %s", scheduler_stats) + except Exception: + logger.exception("Error preparing object to dump") diff --git a/vllm/logging_utils/formatter.py b/vllm/logging_utils/formatter.py new file mode 100644 index 0000000..0affef1 --- /dev/null +++ b/vllm/logging_utils/formatter.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import logging + + +class NewLineFormatter(logging.Formatter): + """Adds logging prefix to newlines to align multi-line messages.""" + + def __init__(self, fmt, datefmt=None, style="%"): + logging.Formatter.__init__(self, fmt, datefmt, style) + + def format(self, record): + msg = logging.Formatter.format(self, record) + if record.message != "": + parts = msg.split(record.message) + msg = msg.replace("\n", "\r\n" + parts[0]) + return msg diff --git a/vllm/logits_process.py b/vllm/logits_process.py new file mode 100644 index 0000000..5967d08 --- /dev/null +++ b/vllm/logits_process.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Callable, Union + +import torch + +from vllm.transformers_utils.tokenizer import AnyTokenizer + +LogitsProcessor = Union[ + Callable[[list[int], torch.Tensor], torch.Tensor], + Callable[[list[int], list[int], torch.Tensor], torch.Tensor], +] +"""LogitsProcessor is a function that takes a list +of previously generated tokens, the logits tensor +for the next token and, optionally, prompt tokens as a +first argument, and returns a modified tensor of logits +to sample from.""" + + +def get_bad_words_logits_processors( + bad_words: list[str], + tokenizer: AnyTokenizer) -> list[LogitsProcessor]: + bad_words_ids: list[list[int]] = list() + + for bad_word in bad_words: + # To prohibit words both at the beginning + # and in the middle of text + # (related to add_prefix_space tokenizer parameter) + for add_prefix_space in [False, True]: + prefix = " " if add_prefix_space else "" + prompt = prefix + bad_word.lstrip() + + prompt_token_ids = tokenizer.encode(text=prompt, + add_special_tokens=False) + + # If no space at the beginning + # or if prefix space produces a new word token + if (not add_prefix_space) or ( + add_prefix_space + and prompt_token_ids[0] != bad_words_ids[-1][0] + and len(prompt_token_ids) == len(bad_words_ids[-1])): + bad_words_ids.append(prompt_token_ids) + + return [NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids)] + + +class NoBadWordsLogitsProcessor: + _SMALLEST_LOGIT = float("-inf") + _NEUTRAL_LOGIT = 0.0 + + def __init__(self, bad_words_ids: list[list[int]]): + self.bad_words_ids = bad_words_ids + self.word_bias: torch.FloatTensor = None + + def __call__( + self, + past_tokens_ids: Union[list[int], tuple[int]], + logits: torch.FloatTensor, + ) -> torch.Tensor: + if self.word_bias is None: + self._init_word_bias(logits=logits) + + last_token_bias = torch.zeros_like(logits) + + for bad_word_ids in self.bad_words_ids: + if len(bad_word_ids) == 1: # 1-token words already processed + continue + + if len(bad_word_ids) > len(past_tokens_ids) + 1: + continue + + prefix_length = len(bad_word_ids) - 1 + last_token_id = bad_word_ids[-1] + actual_prefix = past_tokens_ids[-prefix_length:] + expected_prefix = bad_word_ids[:prefix_length] + + assert len(actual_prefix) == len(expected_prefix) + + is_match = tuple(actual_prefix) == tuple(expected_prefix) + last_token_bias[last_token_id] += (self._SMALLEST_LOGIT if is_match + else self._NEUTRAL_LOGIT) + + logits = logits + self.word_bias + last_token_bias + + return logits + + def _init_word_bias(self, logits: torch.FloatTensor) -> None: + # Code based on NoBadWordsLogitsProcessor and SequenceBiasLogitsProcessor # noqa: E501 + # from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py + + vocab_size = logits.shape[-1] + + self._check_token_ids_bounds(vocab_size=vocab_size) + + self.word_bias = torch.zeros((vocab_size, ), + dtype=torch.float, + device=logits.device) + + for bad_word_ids in self.bad_words_ids: + if len(bad_word_ids) == 1: + bad_word_id = bad_word_ids[-1] + self.word_bias[bad_word_id] = self._SMALLEST_LOGIT + + def _check_token_ids_bounds(self, vocab_size: int) -> None: + invalid_token_ids = [] + + for bad_word_ids in self.bad_words_ids: + for token_id in bad_word_ids: + if token_id < 0 or token_id >= vocab_size: + invalid_token_ids.append(token_id) + + if len(invalid_token_ids) > 0: + raise ValueError( + f"The model vocabulary size is {vocab_size}," + f" but the following tokens" + f" were specified as bad: {invalid_token_ids}." + f" All token id values should be integers satisfying:" + f" 0 <= token_id < {vocab_size}.") diff --git a/vllm/lora/__init__.py b/vllm/lora/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py new file mode 100644 index 0000000..7fc4cfe --- /dev/null +++ b/vllm/lora/fully_sharded_layers.py @@ -0,0 +1,355 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# pylint: disable=unused-argument +from typing import TYPE_CHECKING, Optional, Union, cast + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.distributed.communication_op import ( + tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.lora.layers import (ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, + RowParallelLinearWithLoRA) +from vllm.platforms import current_platform + +if TYPE_CHECKING: + pass + + +def _fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + return (can_replace(*args, **kwargs) + and kwargs["lora_config"].fully_sharded_loras) + + return dec + + +def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA): + """ + For `ColumnParallelLinearWithLoRA` or classes that inherit from + `ColumnParallelLinearWithLoRA`, they share the same `apply` logic. + """ + assert (layer.n_slices == len(layer.lora_a_stacked) == len( + layer.lora_b_stacked) == len(layer.output_slices)) + if layer.lora_bias_stacked is not None: + assert layer.n_slices == len(layer.lora_bias_stacked) + + output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape + + # Since communication is needed, the buffer is directly initialized as a + # tensor rather than a tuple of tensor. + buffers = torch.zeros( + (layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) + + shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink( + buffers, x, layer.lora_a_stacked, 1.0) + + if not current_platform.can_update_inplace(): + buffers = shrunk_buffers + + buffers = tensor_model_parallel_all_gather(buffers) + + lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand( + output, + buffers, + layer.lora_b_stacked, + layer.lora_bias_stacked, + layer.output_slices, + offset_start=0, + add_input=True) + + if not current_platform.can_update_inplace(): + output = lora_output + + output = output.view(*out_orig_shape) + # now have column partitioned and packed output + return output + + +# these layers are based on the tensor parallelism strategy given in +# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, +# https://arxiv.org/abs/2311.03285. + + +class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): + """ + Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`, + # their `lora_a` and `lora_b` have different sharding patterns. After + # completing the `lora_a` GEMM , a gather operation is performed. + # Therefore, the sharding of `lora_a` only needs to correspond with the + # gather operation. + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked[0].shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedColumnParallelLinearWithShardedLoRA( + MergedColumnParallelLinearWithLoRA): + """ + Differs from MergedColumnParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a( + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: + #NOTE: lora_a contains 2 subloras, and each sublora could be None. + output_shard_size = self.lora_a_stacked[0].shape[2] + output_start_idx = self.tp_rank * output_shard_size + lora_a = [ + lora_a[0][:, output_start_idx:output_start_idx + + output_shard_size] if lora_a[0] is not None else None, + lora_a[1][:, output_start_idx:output_start_idx + + output_shard_size] if lora_a[1] is not None else None, + ] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA): + """ + Differs from QKVParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked[0].shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: list, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): + """ + Differs from MergedQKVParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a( + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: + # NOTE: lora_a contains 3 subloras, and each sublora could be None. + shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] + start_idx = [self.tp_rank * shard_size[i] for i in range(3)] + lora_a = [ + lora_a[0][:, start_idx[0]:start_idx[0] + + shard_size[0]] if lora_a[0] is not None else None, + lora_a[1][:, start_idx[1]:start_idx[1] + + shard_size[1]] if lora_a[1] is not None else None, + lora_a[2][:, start_idx[2]:start_idx[2] + + shard_size[2]] if lora_a[2] is not None else None, + ] + return lora_a + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): + """ + Differs from RowParallelLinearWithLoRA by slicing the + LoRA B's also. + + Based on S-LoRA, slicing happens along the output dim. + This yields a combined partial sum from the row parallel base + layer and column partitioned output from the LoRA. + """ + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + shard_size = self.lora_b_stacked[0].shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + if bias is None: + return bias + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], + self.lora_bias_stacked) + shard_size = self.lora_bias_stacked[0].shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + bias = bias[start_idx:end_idx] + return bias + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, + output.shape[-1]), output.shape + buffer = torch.zeros( + (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) + + shrunk_buffer: Optional[torch.Tensor] = self.punica_wrapper.add_shrink( + buffer, x, self.lora_a_stacked, 1.0) + if not current_platform.can_update_inplace(): + buffer = shrunk_buffer + + buffer = tensor_model_parallel_all_reduce(buffer) + + # following S-LoRA, allows the fusing of all_gather and all_reduce + # by adding the column partitioned lora output to a slice of output + # tensor, which is a partial sum due to row parallel. All that + # remains is a standard all_reduce. User should be aware though that + # the output is not the same as a normal row_parallel, it should be + # reduced before being used + # NOTE offset are based on the rank. + shard_size = self.lora_b_stacked[0].shape[2] + offset_start = self.tp_rank * shard_size + lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_expand( + output, + buffer, + self.lora_b_stacked, + self.lora_bias_stacked, + self.output_slices, + offset_start=offset_start, + add_input=True, + ) + + if not current_platform.can_update_inplace(): + output = lora_output + + output = output.view(*out_orig_shape) + return output + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py new file mode 100644 index 0000000..3d0c583 --- /dev/null +++ b/vllm/lora/layers.py @@ -0,0 +1,1285 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# pylint: disable=unused-argument +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Union, cast + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.adapter_commons.layers import AdapterMapping +from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) +from vllm.distributed.utils import divide +# yapf: disable +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +# yapf: enable +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import ( + LinearScalingRotaryEmbedding, RotaryEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.platforms import current_platform + +if TYPE_CHECKING: + from vllm.lora.punica_wrapper import PunicaWrapperBase + + +def _get_lora_device(base_layer: nn.Module) -> torch.device: + # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 + """Returns the device for where to place the LoRA tensors.""" + # unquantizedLinear + if hasattr(base_layer, "weight"): + return base_layer.weight.device + # Compressed Tensor + elif hasattr(base_layer, "weight_packed"): + return base_layer.weight_packed.device + # GPTQ/AWQ + elif hasattr(base_layer, "qweight"): + return base_layer.qweight.device + # marlin + elif hasattr(base_layer, "B"): + return base_layer.B.device + # HQQ marlin + elif hasattr(base_layer, "W_q"): + return base_layer.W_q.device + else: + raise ValueError(f"Unsupported base layer: {base_layer}") + + +def _not_fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of not using fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + decorate = kwargs.pop("decorate") if "decorate" in kwargs else True + condition = (not kwargs["lora_config"].fully_sharded_loras + if decorate else True) + return can_replace(*args, **kwargs) and condition + + return dec + + +@dataclass +class LoRAMapping(AdapterMapping): + is_prefill: bool = False + + +class BaseLayerWithLoRA(nn.Module): + + def slice_lora_a( + self, lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: + """Slice lora a if splitting for tensor parallelism.""" + ... + + def slice_lora_b( + self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: + """Slice lora b if splitting with tensor parallelism.""" + ... + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """Initializes lora matrices.""" + ... + + def reset_lora(self, index: int): + """Resets the lora weights at index back to 0.""" + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + """Overwrites lora tensors at index.""" + ... + + def set_mapping( + self, + punica_wrapper, + ): + self.punica_wrapper: PunicaWrapperBase = punica_wrapper + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + raise NotImplementedError + + +class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + self.embeddings_slice: Optional[tuple[int, int]] + self.embeddings_weights: Optional[torch.Tensor] + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + + if self.base_layer.num_added_embeddings_per_partition > 0: + # We can start adding lora weights + self.embeddings_weights = self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition:self. + base_layer.num_org_embeddings_per_partition + + self.base_layer.num_added_embeddings_per_partition] + self.embeddings_slice = ( + self.base_layer.shard_indices.added_vocab_start_index - + self.base_layer.org_vocab_size, + self.base_layer.shard_indices.added_vocab_end_index - + self.base_layer.org_vocab_size) + self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition:].fill_(0) + else: + self.embeddings_slice = None + self.embeddings_weights = None + + self.embeddings_tensors = torch.zeros( + ( + max_loras, + lora_config.lora_extra_vocab_size, + self.base_layer.embedding_dim, + ), + dtype=self.base_layer.weight.dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.org_vocab_size + + lora_config.lora_extra_vocab_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.embedding_dim, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked_2d = self.lora_a_stacked.view( + self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], + self.lora_a_stacked.shape[2], + ) + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( + lora_a, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, + :embeddings_tensor.shape[0], + :embeddings_tensor.shape[1], + ].copy_(embeddings_tensor, non_blocking=True) + if self.embeddings_slice is not None: + # TODO(yard1): Optimize this copy, we don't need to copy + # everything, just the modified part + embeddings = self.embeddings_tensors.view( + self.embeddings_tensors.shape[0] * + self.embeddings_tensors.shape[1], + self.embeddings_tensors.shape[2], + )[self.embeddings_slice[0]:self.embeddings_slice[1]] + assert self.embeddings_weights is not None + self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, + 1, 0) + embeddings_indices = torch.narrow( + self.punica_wrapper._embeddings_indices, 1, 0, x.size(0)) + + indices = embeddings_indices[1] + full_lora_a_embeddings = F.embedding( + x + indices, + self.lora_a_stacked_2d, + ) + indices = embeddings_indices[0] + full_output = self.base_layer.forward(x + + (indices * added_tokens_mask)) + + full_output_org = full_output + if full_output.ndim == 3: + full_output = full_output.view( + full_output.shape[0] * full_output.shape[1], -1) + if full_lora_a_embeddings.ndim == 3: + full_lora_a_embeddings = full_lora_a_embeddings.view( + full_lora_a_embeddings.shape[0] * + full_lora_a_embeddings.shape[1], + -1, + ) + + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_embedding( + full_output, + full_lora_a_embeddings, + self.lora_b_stacked, + add_input=True) + + if not current_platform.can_update_inplace(): + full_output = lora_output + + return full_output.view_as(full_output_org) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is VocabParallelEmbedding + + @property + def weight(self): + return self.base_layer.weight + + +class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: LinearBase): + super().__init__() + self.base_layer = base_layer + self.input_size = self.base_layer.input_size + self.device = _get_lora_device(self.base_layer) + self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None + + self.output_slices: tuple[int, ...] + self.tp_size: int + self.output_size: int + self.n_slices: int + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + self.lora_config = lora_config + # + if isinstance(self.base_layer, ReplicatedLinear): + lora_a_out_size = lora_config.max_lora_rank + lora_b_out_size = self.output_size + + elif isinstance(self.base_layer, ColumnParallelLinear): + lora_a_out_size = (lora_config.max_lora_rank if + not lora_config.fully_sharded_loras else divide( + lora_config.max_lora_rank, self.tp_size)) + lora_b_out_size = self.output_size + + elif isinstance(self.base_layer, RowParallelLinear): + lora_a_out_size = lora_config.max_lora_rank + lora_b_out_size = (self.output_size if + not lora_config.fully_sharded_loras else divide( + self.output_size, self.tp_size)) + else: + raise NotImplementedError + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_a_out_size, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_b_out_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + if lora_config.bias_enabled: + lora_bias_out_size = lora_b_out_size + self.lora_bias_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_bias_out_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.output_slices = (self.lora_b_stacked[0].shape[2], ) + + def reset_lora(self, index: int): + for s_index in range(self.n_slices): + self.lora_a_stacked[s_index][index] = 0 + self.lora_b_stacked[s_index][index] = 0 + if self.lora_config.bias_enabled: + # Make mypy happy + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], + self.lora_bias_stacked) + self.lora_bias_stacked[s_index][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + lora_bias: Optional[torch.Tensor] = None, + ): + # Except for QKVParallelLinearWithLoRA and + # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers + # store weights in a tuple of size 1. These two layers will + # override this function. + assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) == + self.n_slices == 1) + + self.reset_lora(index) + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) + + self.lora_a_stacked[0][index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[0][index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if lora_bias is not None: + + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], + self.lora_bias_stacked) + assert len(self.lora_bias_stacked) + self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( + lora_bias.T, non_blocking=True) + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + + # In transformers backend, x and output have extra batch dimension like + # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim), + # therefore we need to flatten the batch dimensions. + if x.ndim == 3 and output.ndim == 3: + output = output.flatten(0, 1) + x = x.flatten(0, 1) + + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_linear( + output, x, self.lora_a_stacked, self.lora_b_stacked, + self.lora_bias_stacked, 1.0, self.output_slices) + if not current_platform.can_update_inplace(): + output = lora_output + + return output + + @property + def weight(self) -> torch.Tensor: + + # unquantizedLinear + if hasattr(self.base_layer, "weight"): + return self.base_layer.weight + # Compressed Tensor + elif hasattr(self.base_layer, "weight_packed"): + return self.base_layer.weight_packed + # GPTQ/AWQ + elif hasattr(self.base_layer, "qweight"): + return self.base_layer.qweight + # marlin + elif hasattr(self.base_layer, "B"): + return self.base_layer.B + # HQQ marlin + elif hasattr(self.base_layer, "W_q"): + return self.base_layer.W_q + else: + raise ValueError(f"Unsupported base layer: {self.base_layer}") + + @property + def bias(self) -> Optional[torch.Tensor]: + if hasattr(self.base_layer, "bias"): + return self.base_layer.bias + else: + return None + + +class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): + + def __init__(self, base_layer: ReplicatedLinear) -> None: + super().__init__(base_layer, ) + # To ensure interface compatibility, set to 1 always. + self.tp_size = 1 + self.output_size = self.base_layer.output_size + self.n_slices = 1 + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Forward of ReplicatedLinearWithLoRA + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output = self.apply(input_, bias) + + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + + if not self.base_layer.return_bias: + return output + + return output, output_bias + + # ReplicatedLinear should always be replaced, regardless of the fully + # sharded LoRAs setting, because it is, by definition, copied per GPU. + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is ReplicatedLinear + + +class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): + """ + LoRA on top of ColumnParallelLinear layer. + LoRA B is sliced for tensor parallelism. + There are two types for the `base_layer`: + 1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`. + 2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`. + """ + + def __init__(self, base_layer: ColumnParallelLinear) -> None: + super().__init__(base_layer) + # The base_layer type is ColumnParallelLinear or + # MergedColumnParallelLinear, their weight sharding logic is + # inconsistent when TP is greater than 1. + self.is_merged_col_linear = type( + base_layer) is MergedColumnParallelLinear + self.tp_size = get_tensor_model_parallel_world_size() + self.output_size = self.base_layer.output_size_per_partition + # There is only one LoRA layer + self.n_slices = 1 + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + # Applicable to cases where the base_layer is + # MergedColumnParallelLinear. + if self.is_merged_col_linear: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size // 2 + offset = lora_b.shape[-1] // 2 + + left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) * + shard_size] + right_weight = lora_b[:, offset + tp_rank * shard_size:offset + + (tp_rank + 1) * shard_size] + lora_b = torch.cat([left_weight, right_weight], dim=1) + # Applicable to cases where the base_layer is + # ColumnParallelLinear. + else: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + # TODO: Fix the slicing logic of bias. + if bias is None: + return bias + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + bias = bias[start_idx:end_idx] + return bias + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Forward of ColumnParallelLinear + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output_parallel = self.apply(input_, bias) + if self.base_layer.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + + if not self.base_layer.return_bias: + return output + + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + return output, output_bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is ColumnParallelLinear or ( + type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 1) + + +class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + """ColumnParallelLinear layer that is composed of 2 sublayers (slices) + packed together (eg. gate_proj + up_proj -> gate_up_proj). + + This means we have 2 LoRAs, each applied to one half of the layer. + + Both slices must have the same size. + """ + + def __init__( + self, base_layer: Union[MergedColumnParallelLinear, + QKVParallelLinear]) -> None: + super().__init__(base_layer) + # There are two LoRA layers + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + # the output_sizes in MergedColumnParallelLinear is not sharded by tp + # we need to divide it by the tp_size to get correct slices size + output_sizes = self.base_layer.output_sizes + self.output_slices = tuple( + divide(output_size, self.tp_size) for output_size in output_sizes) + self.n_slices = len(self.output_slices) + self.output_ids = (self.tp_rank, ) * self.n_slices + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """ + The main reason for overriding this function is to enhance code + maintainability. + """ + self.lora_config = lora_config + + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_a_output_size_per_partition, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + output_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) for output_size in self.output_slices) + if lora_config.bias_enabled: + self.lora_bias_stacked = tuple( + torch.zeros( + max_loras, + 1, + output_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for output_size in self.output_slices) + + def slice_lora_a( + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: + return lora_a + + def slice_lora_b( + self, lora_b: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (lora_b_i := lora_b[i]) is not None: + lora_b[i] = lora_b_i[:, shard_size * shard_id:shard_size * + (shard_id + 1)] + return lora_b + + def slice_bias( + self, bias: list[Union[torch.Tensor, + None]]) -> list[Union[torch.Tensor, None]]: + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (bias_i := bias[i]) is not None: + bias[i] = bias_i[shard_size * shard_id:shard_size * + (shard_id + 1)] + return bias + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + lora_bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) + + for i in range(self.n_slices): + if (lora_a_i := lora_a[i]) is not None: + self.lora_a_stacked[i][ + index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_( + lora_a_i.T, non_blocking=True) + if (lora_b_i := lora_b[i]) is not None: + self.lora_b_stacked[i][ + index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_( + lora_b_i.T, non_blocking=True) + + if lora_bias is not None: + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], + self.lora_bias_stacked) + for i in range(self.n_slices): + if (lora_bias_i := lora_bias[i]) is not None: + self.lora_bias_stacked[i][index, + 0, :lora_bias_i.shape[0]].copy_( + lora_bias_i.T, + non_blocking=True) + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 2) + + +class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + """ + ColumnParallelLinear layer that is specifically designed for + qkv_proj. Certain models, such as chatglm3 and baichuan-7b, + only contains a single LoRA within their qkv_proj layer. + + During inference with Tensor Parallel, the weights of lora_b + must be accurately partitioned according to the respective ranks. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + self.q_proj_total_size = (self.base_layer.total_num_heads * + self.base_layer.head_size) + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.kv_proj_total_size = (self.base_layer.total_num_kv_heads * + self.base_layer.head_size) + # There is only one LoRA layer + self.n_slices = 1 + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + self.q_shard_id = tp_rank + self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + lora_b_q = lora_b[:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + k_offset = self.q_proj_total_size + lora_b_k = lora_b[:, k_offset + + self.kv_proj_shard_size * self.kv_shard_id:k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + v_offset = k_offset + self.kv_proj_total_size + lora_b_v = lora_b[:, v_offset + + self.kv_proj_shard_size * self.kv_shard_id:v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + bias_q = bias[self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + k_offset = self.q_proj_total_size + bias_k = bias[k_offset + + self.kv_proj_shard_size * self.kv_shard_id:k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + v_offset = k_offset + self.kv_proj_total_size + bias_v = bias[v_offset + + self.kv_proj_shard_size * self.kv_shard_id:v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + bias = torch.cat([bias_q, bias_k, bias_v], dim=1) + return bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: list, + model_config: Optional[PretrainedConfig]) -> bool: + return type(source_layer) is QKVParallelLinear and len( + packed_modules_list) == 1 + + +class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): + """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) + packed together in qkv proj fashion + (q_proj + k_proj + v_proj -> qkv_proj). + + This means we have 3 LoRAs, each applied to one slice of the layer. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + # There are three LoRA layer. + self.n_slices = len(self.base_layer.output_sizes) + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.q_shard_id = self.tp_rank + self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas + + self.output_slices = ( + self.q_proj_shard_size, + self.kv_proj_shard_size, + self.kv_proj_shard_size, + ) + self.output_ids = ( + self.q_shard_id, + self.kv_shard_id, + self.kv_shard_id, + ) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """ + The main reason for overloading this function is to handle inconsistent + weight dimensions in qkv lora. + """ + super().create_lora_weights(max_loras, lora_config, model_config) + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is QKVParallelLinear + and len(packed_modules_list) == 3) + + +#TODO: Implement this +class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA): + pass + + +class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): + + def __init__(self, base_layer: RowParallelLinear) -> None: + super().__init__(base_layer) + + self.tp_size = get_tensor_model_parallel_world_size() + # reset input_size + self.input_size = self.base_layer.input_size_per_partition + self.output_size = self.base_layer.output_size + + self.tp_rank = get_tensor_model_parallel_rank() + # There is only one LoRA layer. + self.n_slices = 1 + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + + shard_size = self.input_size + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_a = lora_a[start_idx:end_idx, :] + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + return lora_b + + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + return bias + + def forward( + self, input_: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: + """Forward of RowParallelLinear + + Args: + input_: tensor whose last dimension is `input_size`. If + `input_is_parallel` is set, then the last dimension + is `input_size // tp_size`. + + Returns: + - output + - bias + """ + # set up backprop all-reduce. + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + # TODO: simplify code below + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.base_layer.tp_size) + input_parallel = splitted_input[self.tp_rank].contiguous() + + # Matrix multiply. + output_parallel = self.apply(input_parallel) + if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.base_layer.skip_bias_add: + output = (output_ + self.base_layer.bias + if self.base_layer.bias is not None else output_) + output_bias = None + else: + output = output_ + output_bias = self.base_layer.bias + + if not self.base_layer.return_bias: + return output + + return output, output_bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is RowParallelLinear + + +class LogitsProcessorWithLoRA(BaseLayerWithLoRA): + """ + LoRA wrapper for LogitsProcessor, with extra logic to handle the + application of the LoRA adapter and added LoRA vocabulary. + + Args: + base_layer: LogitsProcessor layer + hidden_size: hidden size of the model + dtype: data type of the model + device: device of the model + sharded_to_full_mapping: index mapping from sharded vocab to full vocab + received from base_layer.get_sharded_to_full_mapping(). If None, + no reindexing will be done. + """ + + def __init__(self, base_layer: LogitsProcessor, hidden_size: int, + dtype: torch.dtype, device: torch.device, + sharded_to_full_mapping: Optional[list[int]]) -> None: + super().__init__() + self.base_layer = base_layer + self.hidden_size = hidden_size + self.dtype = dtype + self.device = device + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.sharded_to_full_mapping = sharded_to_full_mapping + + @property + def logits_as_input(self): + return self.base_layer.logits_as_input + + @property + def vocab_size(self): + return self.base_layer.vocab_size + + @property + def scale(self): + return self.base_layer.scale + + @property + def soft_cap(self): + return self.base_layer.soft_cap + + @property + def use_all_gather(self): + return self.base_layer.use_all_gather + + @property + def org_vocab_size(self): + return self.base_layer.org_vocab_size + + @property + def include_gpu_probs_tensor(self): + return self.base_layer.include_gpu_probs_tensor + + @property + def should_modify_greedy_probs_inplace(self): + return self.base_layer.should_modify_greedy_probs_inplace + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + # TODO: Verify if this condition can be further relaxed + if 32000 < self.base_layer.vocab_size > 257024: + raise ValueError("When using LoRA, vocab size must be " + "32000 >= vocab_size <= 257024") + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.hidden_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + # Pad for kernel compatibility + math.ceil(self.base_layer.vocab_size / + lora_config.lora_vocab_padding_size) * + lora_config.lora_vocab_padding_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.embeddings_tensors = torch.full( + (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), + fill_value=float("-inf"), + dtype=self.dtype, + device=self.device, + ) + if self.sharded_to_full_mapping is not None: + self.sharded_to_full_mapping_gpu = torch.tensor( + self.sharded_to_full_mapping, + device=self.device, + dtype=torch.long) + else: + self.sharded_to_full_mapping_gpu = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = float("-inf") + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + self.reset_lora(index) + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, + :embeddings_tensor.shape[0], + :embeddings_tensor.shape[1], + ] = embeddings_tensor + + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + # Get the logits for the next tokens. + logits = lm_head.quant_method.apply(lm_head, hidden_states) + if embedding_bias is not None: + logits += embedding_bias + + # Gather logits for TP + logits = self.base_layer._gather_logits(logits) + + if logits is None: + return None + + if self.sharded_to_full_mapping_gpu is not None: + # Reindex full logits tensor to ensure 1:1 mapping between + # index and token_id + # Example for: + # org_vocab_size = 4 + # added_vocab_size = 2 + # pad_to_size = 8 + # tp_size = 2 + + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 4, -1, 2, 3, 5, -1] + + # Therefore, the mapping is expected to be: + # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, + # we get: + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 2, 3, 4, 5, -1, -1] + logits = logits[:, self.sharded_to_full_mapping_gpu] + + lora_logits = torch.empty( + self.embeddings_tensors.shape[0] + 1, + self.embeddings_tensors.shape[1], + hidden_states.shape[0], + dtype=self.embeddings_tensors.dtype, + device=self.embeddings_tensors.device, + ) + torch.matmul(self.embeddings_tensors, + hidden_states.T, + out=lora_logits[:-1]) + + neg_inf, pos_inf = current_platform.get_infinity_values( + lora_logits.dtype) + + lora_logits[-1] = neg_inf + lora_logits = lora_logits.mT + indices_padded = self.punica_wrapper.sampler_indices_padded + + if current_platform.is_tpu(): + indices_padded = indices_padded[:logits.size(0)] + + lora_logits = (lora_logits.reshape( + lora_logits.shape[0] * lora_logits.shape[1], + lora_logits.shape[2], + ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf, + posinf=pos_inf, + neginf=neg_inf)) + + # HPU needs special handling to prune out dummy samples. + if current_platform.is_hpu(): + lora_logits = lora_logits[:logits.shape[0], :] + + logits[:, + self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + + lora_logits.shape[1]] = lora_logits + + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_logits( + logits, hidden_states, self.lora_a_stacked, + self.lora_b_stacked, 1.0) + + if not current_platform.can_update_inplace(): + logits = lora_output + + # Remove paddings in vocab (if any). + logits = logits[:, :self.base_layer.vocab_size] + return logits + + def forward(self, *args, **kwargs): + return type(self.base_layer).forward(self, *args, **kwargs) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + # Special handling for the LogitsProcessor. + return False + + +class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA): + """Implements RoPE-scaled embeddings with linear scaling for + multiple LoRA adapters with a specialized kernel. + + Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding + which can handle multi lora adapters in a specialized kernel. + """ + + def __init__(self, base_layer: RotaryEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + + @property + def scaling_factors(self): + return self.base_layer.scaling_factors + + @property + def rotary_dim(self): + return self.base_layer.rotary_dim + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + scaling_factors = (list(lora_config.long_lora_scaling_factors) + if lora_config.long_lora_scaling_factors else []) + base_scaling_factor = (self.base_layer.scaling_factor if isinstance( + self.base_layer, LinearScalingRotaryEmbedding) else 1.0) + scaling_factors = sorted( + list(set([base_scaling_factor] + scaling_factors))) + self.base_layer = LinearScalingRotaryEmbedding( + self.base_layer.head_size, + self.base_layer.rotary_dim, + self.base_layer.max_position_embeddings, + self.base_layer.base, + self.base_layer.is_neox_style, + scaling_factors, + self.base_layer.dtype, + ) + + def reset_lora(self, index: int): + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + ... + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + return self.base_layer( + positions, + query, + key, + offsets=self.punica_wrapper.long_lora_indices, + ) + + @property + def scaling_factor_to_offset(self) -> dict[float, int]: + return self.base_layer.scaling_factor_to_offset + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + return (type(source_layer) is LinearScalingRotaryEmbedding + or type(source_layer) is RotaryEmbedding) + + def extra_repr(self) -> str: + return self.base_layer.extra_repr() diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py new file mode 100644 index 0000000..958364f --- /dev/null +++ b/vllm/lora/lora.py @@ -0,0 +1,199 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence as GenericSequence +from typing import Optional + +import torch +import torch.types + +from vllm.lora.peft_helper import PEFTHelper +from vllm.utils import is_pin_memory_available + + +class LoRALayerWeights: + """LoRA weights for a layer composed of two low rank matrixes.""" + + def __init__( + self, + module_name: str, + rank: int, + lora_alpha: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + bias: Optional[torch.Tensor] = None, + embeddings_tensor: Optional[torch.Tensor] = None, + scaling: Optional[float] = None, + ) -> None: + self.module_name = module_name + self.rank = rank + self.lora_alpha = lora_alpha + self.lora_a = lora_a + self.lora_b = lora_b + self.bias = bias + self.embeddings_tensor = embeddings_tensor + + if scaling is None: + self.scaling = self.lora_alpha / self.rank + else: + self.scaling = scaling + + def optimize(self) -> "LoRALayerWeights": + """Optimize the LoRA by merging the scaling into lora_b.""" + if self.scaling == 1: + return self + self.lora_b *= self.scaling + self.scaling = 1 + return self + + @property + def input_dim(self) -> int: + return self.lora_a.shape[0] + + @property + def output_dim(self) -> int: + return self.lora_b.shape[1] + + @property + def is_packed(self) -> bool: + return False + + @property + def extra_vocab_size(self) -> int: + return self.embeddings_tensor.shape[ + 0] if self.embeddings_tensor is not None else 0 + + @classmethod + def from_config( + cls, + module_name: str, + peft_helper: PEFTHelper, + embeddings_tensor: Optional[torch.Tensor] = None, + ) -> "LoRALayerWeights": + return cls(module_name, peft_helper.r, peft_helper.lora_alpha, None, + None, None, embeddings_tensor, + peft_helper.vllm_lora_scaling_factor) + + @classmethod + def create_dummy_lora_weights( + cls, + module_name: str, + input_dim: int, + output_dim: int, + rank: int, + dtype: torch.dtype, + device: torch.types.Device, + embeddings_tensor_dim: Optional[int] = None, + bias_enabled: Optional[bool] = False) -> "LoRALayerWeights": + pin_memory = str(device) == "cpu" and is_pin_memory_available() + lora_a = torch.zeros([input_dim, rank], + dtype=dtype, + device=device, + pin_memory=pin_memory) + lora_b = torch.zeros([rank, output_dim], + dtype=dtype, + device=device, + pin_memory=pin_memory) + if bias_enabled: + bias = torch.zeros([output_dim], + dtype=dtype, + device=device, + pin_memory=pin_memory) + else: + bias = None + + embeddings_tensor = torch.rand( + 10, + embeddings_tensor_dim, + dtype=dtype, + device=device, + pin_memory=pin_memory) if embeddings_tensor_dim else None + return cls( + module_name, + rank=rank, + lora_alpha=1, + lora_a=lora_a, + lora_b=lora_b, + bias=bias, + embeddings_tensor=embeddings_tensor, + ) + + +class PackedLoRALayerWeights(LoRALayerWeights): + """LoRA used for packed layers (eg. qkv_proj).""" + + def __init__( + self, + module_name: str, + rank: int, + lora_alphas: list[Optional[int]], + lora_a: list[Optional[torch.Tensor]], + lora_b: list[Optional[torch.Tensor]], + bias: Optional[list[Optional[torch.Tensor]]] = None, + scaling: Optional[list[float]] = None, + ) -> None: + super().__init__( + module_name=module_name, + rank=rank, + lora_alpha=0, + lora_a=lora_a, + lora_b=lora_b, + bias=bias, + scaling=scaling, # type: ignore + embeddings_tensor=None, + ) + self.lora_alphas = lora_alphas + if scaling is None: + self.scaling = [ # type: ignore + lora_alpha / self.rank # type: ignore # noqa + for lora_alpha in self.lora_alphas + ] + + @classmethod + def pack( + cls, loras: GenericSequence[Optional["LoRALayerWeights"]] + ) -> "PackedLoRALayerWeights": + """Pack a list of LoRAs into a single LoRA. + + If LoRA is None, it signifies that the submodule does not have a LoRA. + """ + first_lora = next(lora for lora in loras if lora is not None) + for lora in loras: + if lora is None: + continue + lora.optimize() + rank = first_lora.rank + module_name = first_lora.module_name + obj = cls( + module_name, + rank, + [lora.lora_alpha if lora is not None else None for lora in loras], + [lora.lora_a if lora is not None else None for lora in loras], + [lora.lora_b if lora is not None else None for lora in loras], + [lora.bias if lora is not None else None for lora in loras], + scaling=[ + 1 if lora is not None else None # type: ignore + for lora in loras + ]) + return obj + + def optimize(self) -> "PackedLoRALayerWeights": + """Optimize the LoRA by merging the scaling into lora_b.""" + for i in range(len(self.lora_b)): + if self.scaling[i] == 1 or self.lora_b[i] is None: # type: ignore + continue + self.lora_b[i] *= self.scaling[i] # type: ignore + self.scaling[i] = 1 # type: ignore + return self + + @property + def input_dim(self) -> int: + raise NotImplementedError() + + @property + def output_dim(self) -> int: + raise NotImplementedError() + + @property + def is_packed(self) -> bool: + return True diff --git a/vllm/lora/models.py b/vllm/lora/models.py new file mode 100644 index 0000000..f6261b9 --- /dev/null +++ b/vllm/lora/models.py @@ -0,0 +1,820 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +import os +from collections.abc import Sequence +from dataclasses import dataclass, field +from typing import Any, Callable, Optional, Union + +import regex as re +import safetensors.torch +import torch +from torch import nn + +from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, + AdapterModelManager) +from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter, + get_adapter, list_adapters, + remove_adapter, set_adapter_mapping) +from vllm.config import LoRAConfig +from vllm.logger import init_logger +from vllm.lora.layers import (BaseLayerWithLoRA, + LinearScalingRotaryEmbeddingWithLoRA, + LoRAMapping) +from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.peft_helper import PEFTHelper +from vllm.lora.punica_wrapper import get_punica_wrapper +from vllm.lora.utils import (from_layer, from_layer_logits_processor, + get_supported_lora_modules, + is_regex_target_modules, + parse_fine_tuned_lora_name, replace_submodule) +from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.model_executor.models import SupportsLoRA, supports_multimodal +from vllm.model_executor.models.interfaces import is_pooling_model +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper +from vllm.model_executor.utils import get_packed_modules_mapping +from vllm.utils import is_pin_memory_available + +logger = init_logger(__name__) + +_GLOBAL_LORA_ID = 0 + + +@dataclass +class LongContextLoRAContext: + """Context for lora adapters that support long context.""" + # The scaling factors to support long context lora fine tuned models. + scaling_factors: list[float] + # dimension to apply rotary embedding. + rot_dim: int + # offsets to the sin_cos_cache for each lora_id loaded. + # This value is dynamically modified. + offsets_by_lora_id: dict[int, int] = field(default_factory=dict) + + +def get_lora_id(): + global _GLOBAL_LORA_ID + _GLOBAL_LORA_ID += 1 + return _GLOBAL_LORA_ID + + +class LoRAModel(AdapterModel): + """A LoRA fine-tuned model.""" + + def __init__( + self, + lora_model_id: int, + rank: int, + loras: dict[str, LoRALayerWeights], + scaling_factor: Optional[float] = None, + ) -> None: + """ + Args: + lora_model_id: The integer id for the lora model. + rank: lora rank. + loras: module name -> weights for lora-replaced layers. + scaling_factor: Scaling factor to support long context lora model. + None if the lora is not tuned for long context support. + """ + self.id = lora_model_id + # Scaling factor for long context lora model. None if it is not + # fine tuned for the long context. + self.scaling_factor = scaling_factor + assert ( + lora_model_id + > 0), f"a valid lora id should be greater than 0, got {self.id}" + self.rank = rank + self.loras: dict[str, LoRALayerWeights] = loras + + def clone(self, lora_model_id: int) -> "LoRAModel": + """Return a copy of the object with different ids. + + Will share the underlying tensors.""" + return self.__class__( + lora_model_id, + rank=self.rank, + loras=self.loras.copy(), + ) + + @property + def extra_vocab_size(self) -> int: + return max(lora.extra_vocab_size + for lora in self.loras.values()) if self.loras else 0 + + def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]: + """Get LoRA for a given module by name""" + return self.loras.get(module_name, None) + + def check_lora_name(self, lora_name: str) -> bool: + return lora_name in self.loras + + # (yard1): TODO see if we can derive target_embedding_padding automatically + @classmethod + def from_lora_tensors( + cls, + lora_model_id: int, + tensors: dict[str, torch.Tensor], + peft_helper: PEFTHelper, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + embeddings: Optional[dict[str, torch.Tensor]] = None, + target_embedding_padding: Optional[int] = None, + embedding_modules: Optional[dict[str, str]] = None, + embedding_padding_modules: Optional[list[str]] = None, + weights_mapper: Optional[WeightsMapper] = None, + ) -> "LoRAModel": + """Create a LoRAModel from a dictionary of tensors.""" + pin_memory = str(device) == "cpu" and is_pin_memory_available() + loras: dict[str, LoRALayerWeights] = {} + for tensor_name, tensor in tensors.items(): + module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name( + tensor_name, weights_mapper) + if module_name not in loras: + lora_embeddings_tensor = None + if embeddings: + assert embedding_modules is not None + embeddings_module = next( + (k for k in embedding_modules if k in module_name), + None) + if embeddings_module: + lora_embeddings_tensor = embeddings[ + embedding_modules[embeddings_module]].to( + device=device, dtype=dtype) + if pin_memory: + lora_embeddings_tensor = ( + lora_embeddings_tensor.pin_memory()) + loras[module_name] = LoRALayerWeights.from_config( + module_name, peft_helper, lora_embeddings_tensor) + + if is_bias: + loras[module_name].bias = tensor.to(device=device, + dtype=dtype).t() + bias = tensor.to(device=device, dtype=dtype).t() + if pin_memory: + bias = bias.pin_memory() + loras[module_name].bias = bias + elif is_lora_a: + loras[module_name].lora_a = tensor.to(device=device, + dtype=dtype).t() + if pin_memory: + loras[module_name].lora_a = loras[ + module_name].lora_a.pin_memory() + else: + loras[module_name].lora_b = tensor.to(device=device, + dtype=dtype).t() + assert embedding_padding_modules is not None + if any(name in module_name + for name in embedding_padding_modules + ) and target_embedding_padding is not None: + lora_b = loras[module_name].lora_b + assert target_embedding_padding >= lora_b.shape[1] + addition = target_embedding_padding - lora_b.shape[1] + loras[module_name].lora_b = torch.nn.functional.pad( + lora_b, (0, addition)) + if pin_memory: + loras[module_name].lora_b = loras[ + module_name].lora_b.pin_memory() + + for lora in loras.values(): + lora.optimize() + + return cls(lora_model_id, + peft_helper.r, + loras, + scaling_factor=peft_helper.vllm_long_context_scaling_factor) + + @classmethod + def from_local_checkpoint( + cls, + lora_dir: str, + expected_lora_modules: list[str], + peft_helper: PEFTHelper, + *, + lora_model_id: Optional[int] = None, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + target_embedding_padding: Optional[int] = None, + embedding_modules: Optional[dict[str, str]] = None, + embedding_padding_modules: Optional[list[str]] = None, + weights_mapper: Optional[WeightsMapper] = None, + tensorizer_config_dict: Optional[dict] = None) -> "LoRAModel": + """Create a LoRAModel from a local checkpoint. + + Args: + lora_dir: The local path that has lora data. + expected_lora_modules: Name of modules that are expected to be + replaced by lora. + peft_helper: Loaded lora configuration information. + lora_model_id: LoRA model id. If not given, automatically set by + a global counter. + device: Device where the lora model is loaded. + dtype: dtype of the lora model weights. + + Returns: + Loaded LoRA Model. + """ + lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") + lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") + new_embeddings_tensor_path = os.path.join( + lora_dir, "new_embeddings.safetensors") + new_embeddings_bin_file_path = os.path.join(lora_dir, + "new_embeddings.bin") + tensors: dict[str, torch.Tensor] = {} + unexpected_modules: list[Union[list[str], str]] = [] + + def check_unexpected_modules(modules: dict): + for lora_module in modules.keys(): # noqa + module_name, _, _ = parse_fine_tuned_lora_name( + lora_module, weights_mapper) + part_name = module_name.split(".")[-1] + if part_name not in expected_lora_modules: + unexpected_modules.append(module_name) + if unexpected_modules: + raise ValueError( + f"While loading {lora_dir}, expected" + f" target modules in {expected_lora_modules}" + f" but received {unexpected_modules}." + f" Please verify that the loaded LoRA module is correct") + + if tensorizer_config_dict: + from tensorizer import TensorDeserializer + + tensorizer_config = TensorizerConfig(**tensorizer_config_dict) + lora_tensor_path = os.path.join(tensorizer_config.tensorizer_dir, + "adapter_model.tensors") + tensorizer_args = tensorizer_config._construct_tensorizer_args() + tensors = TensorDeserializer(lora_tensor_path, + dtype=tensorizer_config.dtype, + **tensorizer_args.deserializer_params) + check_unexpected_modules(tensors) + + elif os.path.isfile(lora_tensor_path): + # Find unexpected modules. + # Use safetensor key as a source of truth to find expected modules. + # in peft if you have target_modules A, B, C and C does not exist + # in the model it won’t error and model will be trained with A, B + # loraified. C won’t exist in the safetensor but it will exist in + # the target_modules of the adapter_config.json. + unexpected_modules = [] + with safetensors.safe_open(lora_tensor_path, + framework="pt") as f: # type: ignore + # Load tensors if there are only expected modules. + check_unexpected_modules(f) + for module in f.keys(): # noqa + tensors[module] = f.get_tensor(module) + elif os.path.isfile(lora_bin_file_path): + # When a bin file is provided, we rely on config to find unexpected + # modules. + unexpected_modules = [] + target_modules = peft_helper.target_modules + if not isinstance(target_modules, list): + target_modules = [target_modules] + for module in target_modules: + # Compatible with more modules, + # such as:layers.11.self_attn.k_proj + part_name = module.split(".")[-1] + if part_name not in expected_lora_modules: + unexpected_modules.append(module) + # loaded lora's target modules must be a subset of + # expected_lora_modules. It is not reliable. See + # https://github.com/vllm-project/vllm/pull/5909. But there's no + # other better mechanism. + if unexpected_modules and not is_regex_target_modules( + peft_helper.target_modules, expected_lora_modules): + raise ValueError( + f"While loading {lora_dir}, expected" + f" target modules in {expected_lora_modules}" + f" but received {unexpected_modules}." + f" Please verify that the loaded LoRA module is correct") + tensors = torch.load(lora_bin_file_path, + map_location=device, + weights_only=True) + else: + raise ValueError(f"{lora_dir} doesn't contain tensors") + + embeddings = None + if os.path.isfile(new_embeddings_tensor_path): + embeddings = safetensors.torch.load_file( + new_embeddings_tensor_path) + elif os.path.isfile(new_embeddings_bin_file_path): + embeddings = torch.load(new_embeddings_bin_file_path, + map_location=device, + weights_only=True) + + return cls.from_lora_tensors( + lora_model_id=get_lora_id() + if lora_model_id is None else lora_model_id, + tensors=tensors, + peft_helper=peft_helper, + device=device, + dtype=dtype, + embeddings=embeddings, + target_embedding_padding=target_embedding_padding, + embedding_modules=embedding_modules, + embedding_padding_modules=embedding_padding_modules, + weights_mapper=weights_mapper) + + +class LoRAModelManager(AdapterModelManager): + """A manager that manages multiple LoRA-fine-tuned models.""" + + def __init__( + self, + model: SupportsLoRA, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + ): + """Create a LoRAModelManager and adapter for a given model. + + Args: + model: the model to be adapted. + max_num_seqs: the maximum number of sequences model can run in a + single batch. + max_num_batched_tokens: the maximum number of tokens model can run + in a single batch. + vocab_size: the vocab size of the model. + lora_config: the LoRA configuration. + """ + self.lora_config = lora_config + self.device = device + self.max_num_seqs = max_num_seqs + assert self.capacity >= self.lora_slots + self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 + self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots + self.vocab_size = vocab_size + self.long_lora_context: Optional[LongContextLoRAContext] = None + self.punica_wrapper = get_punica_wrapper( + max_num_batched_tokens, + max_batches=self.max_num_seqs, + device=self.device, + max_loras=self.lora_config.max_loras) + # Scaling factor -> offset to the sin_cos_cache to it. + # Used for long context lora. + self.scaling_factor_to_offset: dict[float, int] = {} + super().__init__(model) + + self.supported_lora_modules = get_supported_lora_modules(self.model) + assert self.supported_lora_modules, "No supported LoRA modules found in" + f" {self.model.__class__.__name__}." + if lora_config.lora_target_modules is not None: + self.supported_lora_modules = lora_config.lora_target_modules + if lora_config.long_lora_scaling_factors: + # We need to replace rotary emb layer to do batch computation + # for long lora. + self.supported_lora_modules.append("rotary_emb") + + self.packed_modules_mapping = get_packed_modules_mapping(self.model) + # Used to indicate whether the model is a multimodal model + self.supports_mm: bool = ( + supports_multimodal(self.model) + # In case the model only supports LoRA for + # text modules (e.g. ChatGLM) + and hasattr(self.model, "get_mm_mapping")) + self.is_pooling_model = is_pooling_model(self.model) + self.packed_modules: dict[str, list[str]] = {} + self.modules: dict[str, BaseLayerWithLoRA] = {} + # Dict instead of a set for compatibility with LRUCache. + self._last_mapping: Optional[LoRAMapping] = None + self._create_lora_modules() + self.model.lora_manager = self + self.adapter_type = 'LoRA' + + @property + def capacity(self) -> int: + return self.lora_config.max_cpu_loras + + @property + def lora_slots(self) -> int: + return self.lora_config.max_loras + + @property + def adapter_slots(self) -> int: + return self.lora_slots + + def activate_adapter( + self, + lora_id: int, + ) -> bool: + """Move LoRA into a GPU buffer to be used in the forward pass.""" + if lora_id in self._active_adapters: + return False + first_free_slot = next( + ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id) + if lora_id is None), None) + if first_free_slot is None: + raise ValueError("No free lora slots") + index, _ = first_free_slot + self._active_adapters[lora_id] = None + lora_model = self._registered_adapters[lora_id] + logger.debug("Activating LoRA. int id: %d, slot index: %d", + lora_model.id, index) + self.lora_index_to_id[index] = lora_model.id + for module_name, module in self.modules.items(): + module_lora = self._get_lora_layer_weights(lora_model, module_name) + if module_lora: + module_lora.optimize() + # Bias is not explicitly enabled with the flag enable_lora_bias. + bias = module_lora.bias + if ((torch.is_tensor(bias) or + (isinstance(bias, Sequence) and any(b is not None + for b in bias))) + and not self.lora_config.bias_enabled): + module_lora.bias = None + raise ValueError( + f"Adapter bias cannot be used for {module_name}" + " without --enable-lora-bias.") + module.set_lora(index, module_lora.lora_a, module_lora.lora_b, + module_lora.embeddings_tensor, + module_lora.bias) + else: + module.reset_lora(index) + return True + + def _deactivate_adapter(self, lora_id: int): + try: + index = self.lora_index_to_id.index(lora_id) + self.lora_index_to_id[index] = None + except ValueError: + pass + + def _set_long_lora_context(self, lora: LoRAModel): + if self.long_lora_context is None: + return + + if lora.scaling_factor is None: + return + + if (lora.scaling_factor not in self.scaling_factor_to_offset): + raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}" + " has not been initialized.") + + offsets = self.scaling_factor_to_offset.get(lora.scaling_factor) + if offsets: + self.long_lora_context.offsets_by_lora_id[lora.id] = offsets + + def _add_adapter(self, lora: LoRAModel): + self._create_merged_loras_inplace(lora) + self._registered_adapters[lora.id] = lora + self._set_long_lora_context(lora) + + def pin_adapter(self, lora_id: int) -> bool: + """Pin a LoRAModel in the manager cache.""" + raise NotImplementedError( + "Pinning is not supported in LoRAModelManager. " + "Use LRUCacheLoRAModelManager for pinning") # type: ignore + + def _set_adapter_mapping(self, mapping: LoRAMapping) -> None: + # update lora states + self.punica_wrapper.update_metadata( + mapping, + self.lora_index_to_id, + self.lora_slots + 1, + self.vocab_size, + self.lora_config.lora_extra_vocab_size, + self.long_lora_context, + ) + + def remove_all_adapters(self): + """Remove all LoRAModels from the manager.""" + self._registered_adapters.clear() + self.lora_index_to_id = [None] * self.lora_slots + self._active_adapters.clear() + + def _create_lora_modules(self): + for module_name, module in self.model.named_modules( + remove_duplicate=False): + if isinstance(module, PPMissingLayer): + continue + if not self._match_target_modules(module_name): + continue + # A temporary approach for multimodal models to support LoRA + # TODO: Remove this restriction + if self._filter_unsupported_mm_module(module_name): + logger.warning( + "Regarding multimodal models, vLLM currently only supports " + "adding LoRA to language model, %s will be ignored.", + module_name, + ) + continue + parts = module_name.split(".")[-1] + packed_moduled_lst = self.packed_modules_mapping.get(parts, []) + new_module = replace_submodule( + self.model, module_name, + from_layer(module, self.lora_slots, self.lora_config, + packed_moduled_lst, self.model.config)) + + # LinearScalingRotaryEmbeddingWithLoRA is used to handle + # long context lora. Register relevant metadata. + if isinstance(new_module, LinearScalingRotaryEmbeddingWithLoRA): + self.long_lora_context = LongContextLoRAContext( + new_module.scaling_factors, new_module.rotary_dim) + self.scaling_factor_to_offset = \ + new_module.scaling_factor_to_offset + # (yard1): TODO make this more robust + if "lm_head" in module_name: + logits_processor_module = self.model.get_submodule( + "logits_processor") + new_module = replace_submodule( + self.model, "logits_processor", + from_layer_logits_processor(logits_processor_module, + module, self.lora_slots, + self.lora_config, + self.model.config)) + + # In some models, especially multimodal ones, layers with the same + # name may have different types, such as nn.Linear and + # ReplicatedLinear. The nn.Linear layers cannot be replaced with + # LoRA layers, leading to assertion error. The following check + # aims to prevent this error + if self.supports_mm and not isinstance(new_module, + BaseLayerWithLoRA): + continue + self.register_module(module_name, new_module) + self._register_packed_modules(module_name) + # All lora layers share the same punica_wrapper based on reference. + new_module.set_mapping(self.punica_wrapper) + + def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): + assert isinstance(module, BaseLayerWithLoRA) + self.modules[module_name] = module + + def create_dummy_lora( + self, + lora_id: int, + rank: int, + scaling_factor: Optional[float], + embedding_modules: Optional[dict[str, str]] = None) -> LoRAModel: + """Create zero-initialized LoRAModel for warmup.""" + model = LoRAModel(lora_id, rank, {}, scaling_factor) + for module_name, module in self.model.named_modules(): + bias_enabled = self.lora_config.bias_enabled + if (not self._match_target_modules(module_name) + or not isinstance(module, BaseLayerWithLoRA) + or isinstance(module, LinearScalingRotaryEmbeddingWithLoRA) + or self._filter_unsupported_mm_module(module_name)): + continue + parts = module_name.split(".") + if module_name not in self.packed_modules: + assert embedding_modules is not None + if parts[-1] in embedding_modules: + input_dim = (module.base_layer.org_vocab_size + + self.lora_config.lora_extra_vocab_size if + hasattr(module.base_layer, "org_vocab_size") + else module.base_layer.weight.shape[1]) + output_dim = module.base_layer.embedding_dim if hasattr( + module.base_layer, + "embedding_dim") else module.base_layer.weight.shape[0] + embeddings_tensor_dim = (module.base_layer.embedding_dim if + hasattr(module.base_layer, + "embedding_dim") else + module.base_layer.weight.shape[1]) + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name, + input_dim, + output_dim, + rank, + module.lora_a_stacked[0].dtype, + "cpu", + embeddings_tensor_dim=embeddings_tensor_dim, + bias_enabled=bias_enabled) + else: + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name, + module.lora_a_stacked[0].shape[-1], + module.lora_b_stacked[0].shape[-2], + rank, + module.lora_a_stacked[0].dtype, + "cpu", + bias_enabled=bias_enabled, + ) + lora.optimize() + else: + parts = module_name.split(".") + replacements = self.packed_modules_mapping[parts[-1]] + subloras: list[Optional[LoRALayerWeights]] = [] + for i, r in enumerate(replacements): + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name + "." + r, + module.lora_a_stacked[i].shape[-1], + module.lora_b_stacked[i].shape[-2], + rank, + module.lora_a_stacked[i].dtype, + "cpu", + bias_enabled=bias_enabled, + ) + lora.optimize() + subloras.append(lora) + lora = PackedLoRALayerWeights.pack(subloras) + model.loras[module_name] = lora + return model + + def _match_target_modules(self, module_name: str): + return any( + re.match( + r".*\.{target_module}$".format(target_module=target_module), + module_name) or target_module == module_name + for target_module in self.supported_lora_modules) + + def _filter_unsupported_mm_module(self, module_name: str) -> bool: + """ + Regarding multimodal models, vLLM currently only supports adding LoRA to + language model. LoRA for other modules, such as the vision tower, will + be filtered out. + """ + if self.supports_mm: + module_mapping: MultiModelKeys = self.model.get_mm_mapping() + prefix_lst = module_mapping.connector + module_mapping.tower_model + return any( + [module_name.startswith(prefix) for prefix in prefix_lst]) + return False + + def _register_packed_modules(self, module_full_name: str) -> None: + parts = module_full_name.split(".") + module_name = parts[-1] + replacements = self.packed_modules_mapping.get(module_name, []) + # When replacements is less than or equal to 1, it indicates that this + # module is not a packed module. + if len(replacements) <= 1: + return + prefix = ".".join(parts[:-1]) + self.packed_modules[module_full_name] = [ + prefix + "." + r if prefix else r for r in replacements + ] + + def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: + for module_name, new_module_names in self.packed_modules.items(): + replacement_loras: list[Optional[LoRALayerWeights]] = [] + replaced_module: set[str] = set() + has_replacement = False + for r in new_module_names: + lora = self._get_lora_layer_weights(lora_model, r) + replacement_loras.append(lora) + if lora: + has_replacement = True + replaced_module.add(r) + if not has_replacement: + continue + for i in range(len(replacement_loras)): + if replacement_loras[i]: + continue + replacement_loras[i] = None + # HACK Temporary solution for the pool model. + if self.is_pooling_model and not lora_model.check_lora_name( + module_name): + replaced_module_name = module_name.replace("model.", "") + if lora_model.check_lora_name(module_name): + module_name = replaced_module_name + lora_model.loras[module_name] = PackedLoRALayerWeights.pack( + replacement_loras) + # Remove the modules that have been replaced. + for module in replaced_module: + lora_model.loras.pop(module, None) + + def _get_lora_layer_weights( + self, lora_model: LoRAModel, + module_name: str) -> Optional[LoRALayerWeights]: + org_module_name = module_name + if self.is_pooling_model and not lora_model.check_lora_name( + module_name): + # If it's a pool model, and the layer name is not found, + # remove the prefix 'model.' and search again. + module_name = module_name.replace("model.", "") + if lora_model.check_lora_name(module_name): + org_module_name = module_name + logger.info_once( + "For the pool model, successfully loaded the LoRA weights " + "after removing the prefix 'model.'.") + return lora_model.get_lora(org_module_name) + + def deactivate_adapter(self, adapter_id: int) -> bool: + return deactivate_adapter(adapter_id, self._active_adapters, + self._deactivate_adapter) + + def add_adapter(self, adapter: LoRAModel) -> bool: + logger.debug( + "Adding lora. Model id: %d, " + "int id: %d, " + "scaling factor: %s", adapter.id, adapter.id, + adapter.scaling_factor) + return add_adapter(adapter, self._registered_adapters, self.capacity, + self._add_adapter) + + def set_adapter_mapping(self, mapping: LoRAMapping) -> None: + self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, + self._set_adapter_mapping) + + def remove_adapter(self, adapter_id: int) -> bool: + return remove_adapter(adapter_id, self._registered_adapters, + self.deactivate_adapter) + + def list_adapters(self) -> dict[int, Any]: + return list_adapters(self._registered_adapters) + + def get_adapter(self, adapter_id: int) -> Optional[Any]: + return get_adapter(adapter_id, self._registered_adapters) + + +class LoRALRUCache(AdapterLRUCache[LoRAModel]): + + def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], + bool]): + super().__init__(capacity, deactivate_lora_fn) + + +class LRUCacheLoRAModelManager(LoRAModelManager): + """A model manager that manages multiple LoRAs with LRU cache.""" + + def __init__(self, model: nn.Module, max_num_seqs: int, + max_num_batched_tokens: int, vocab_size: int, + lora_config: LoRAConfig, device: torch.device): + super().__init__(model, max_num_seqs, max_num_batched_tokens, + vocab_size, lora_config, device) + self._registered_adapters: LoRALRUCache = LoRALRUCache( + self.capacity, self.deactivate_adapter) + self._active_adapters: LoRALRUCache = LoRALRUCache( + self.lora_slots, self._deactivate_adapter) + + def list_adapters(self) -> dict[int, LoRAModel]: + """List all registered LoRAModels.""" + return dict(self._registered_adapters.cache) + + def add_adapter(self, lora: LoRAModel) -> bool: + """Add a LoRAModel to the manager.""" + logger.debug( + "Adding lora. Model id: %d, " + "int id: %d, " + "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) + if lora.id not in self._registered_adapters: + self._add_adapter(lora) + was_added = True + else: + # We always touch to update the LRU cache order + self._registered_adapters.touch(lora.id) + was_added = False + return was_added + + def activate_adapter( + self, + lora_id: int, + ) -> bool: + if lora_id not in self._active_adapters and len( + self._active_adapters) >= self.lora_slots: + self._active_adapters.remove_oldest() + result = super().activate_adapter(lora_id) + # We always touch to update the LRU cache order + self._active_adapters.touch(lora_id) + return result + + def remove_oldest_adapter(self) -> bool: + if len(self._registered_adapters) > 0: + self._registered_adapters.remove_oldest() + return True + return False + + def pin_adapter(self, lora_id: int) -> bool: + """Pin a LoRAModel in the manager cache.""" + self._pin_lora_in_cpu_cache(lora_id) + self._pin_lora_in_gpu_cache(lora_id) + return True + + def _pin_lora_in_cpu_cache(self, lora_id: int): + try: + self._registered_adapters.pin(lora_id) + except ValueError as err: + raise ValueError("Pinning failed. " + f"LoRA {lora_id} is not registered.") from err + + def _pin_lora_in_gpu_cache(self, lora_id: int): + if lora_id not in self._active_adapters: + # move lora to gpu if not already active + self.activate_adapter(lora_id) + + self._active_adapters.pin(lora_id) + + +def create_lora_manager( + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + lora_manager_cls: type[LoRAModelManager] = LoRAModelManager, + **kwargs) -> LoRAModelManager: + """Create a LoRA adapter for a given model.""" + if not isinstance(model, SupportsLoRA): + raise ValueError(f"Model {type(model)} is not supported for LoRA.") + lora_manager = lora_manager_cls( + model=model, + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + vocab_size=vocab_size, + lora_config=lora_config, + device=device, + **kwargs) + return lora_manager diff --git a/vllm/lora/ops/__init__.py b/vllm/lora/ops/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/lora/ops/torch_ops/__init__.py b/vllm/lora/ops/torch_ops/__init__.py new file mode 100644 index 0000000..22aa3c6 --- /dev/null +++ b/vllm/lora/ops/torch_ops/__init__.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.lora.ops.torch_ops.lora_ops import bgmv_expand # noqa: F401 +from vllm.lora.ops.torch_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink, + sgmv_expand, sgmv_expand_slice, + sgmv_shrink) + +__all__ = [ + "bgmv_expand", + "bgmv_expand_slice", + "bgmv_shrink", + "sgmv_expand", + "sgmv_expand_slice", + "sgmv_shrink", +] diff --git a/vllm/lora/ops/torch_ops/lora_ops.py b/vllm/lora/ops/torch_ops/lora_ops.py new file mode 100644 index 0000000..cba5baa --- /dev/null +++ b/vllm/lora/ops/torch_ops/lora_ops.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + + +def sgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, + add_inputs) + + +def bgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + limit = output_tensor.shape[0] + if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: + limit = 1 + + # LoRA adapter and model may add different amounts of padding to output + common_len = min(outputs.shape[1], output_tensor.shape[1]) + + if add_inputs: + output_tensor[:, :common_len] += outputs[:limit, :common_len] + else: + output_tensor[:, :common_len] = outputs[:limit, :common_len] + + +def sgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, + scaling) + + +def bgmv_shrink(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] + + +def sgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices, + slice_offset, slice_size, add_inputs) + + +def bgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + inputs = inputs.to(dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + if add_inputs: + output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] + else: + output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:] diff --git a/vllm/lora/ops/triton_ops/__init__.py b/vllm/lora/ops/triton_ops/__init__.py new file mode 100644 index 0000000..805de4b --- /dev/null +++ b/vllm/lora/ops/triton_ops/__init__.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.lora.ops.triton_ops.lora_expand_op import lora_expand +from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta +from vllm.lora.ops.triton_ops.lora_shrink_op import lora_shrink + +__all__ = [ + "lora_expand", + "lora_shrink", + "LoRAKernelMeta", +] diff --git a/vllm/lora/ops/triton_ops/kernel_utils.py b/vllm/lora/ops/triton_ops/kernel_utils.py new file mode 100644 index 0000000..e93064d --- /dev/null +++ b/vllm/lora/ops/triton_ops/kernel_utils.py @@ -0,0 +1,243 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Utilities for Punica kernel construction. +""" +from vllm.triton_utils import tl, triton + + +@triton.jit +def mm_k(a_ptr, b_ptr, ak_stride, bk_stride, offset_k, K: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, CAST_TYPE: tl.constexpr, + b_dtype: tl.constexpr): + """ + Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of + B (k x n), iterate, through the K dimension to compute the partial/complete + matrix block product. + If SPLIT_K == 1, the output m x n product is complete. + If SPLIT_K > 1, the thread block computes partial outputs. The partial + outputs are then atomically summed in the caller code. + Args: + a_ptr: Array of pointers, identifying rows of A + b_ptr: Array of pointers, identifying columns of B + ak_stride: K dimension stride of the A matrix + bk_stride: K dimension stride of the B matrix + K: Length of the K dimension + BLOCK_M: M dimension of the output block m x n + BLOCK_N: N dimension of the output block m x n + BLOCK_K: K dimension atom + EVEN_K: True if the blocks of A and B can be loaded without any + masking. + SPLIT_K: Parameter signifying parallelism in the K dimension. + CAST_TYPE: if True, cast the values from the A matrix to the B + matrix dtype. + b_dtype: datatype of the B matrix + """ + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] + < K - k * (BLOCK_K * SPLIT_K), + other=0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] + < K - k * (BLOCK_K * SPLIT_K), + other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(b_dtype) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * SPLIT_K * ak_stride + b_ptr += BLOCK_K * SPLIT_K * bk_stride + return accumulator + + +@triton.jit +def do_expand_kernel( + pid_n, + lora_index, + slice_id, + input_ptr, + lora_ptr, + out_ptr, + N, + K, + M_LEN, + ram, # array identifying the rows of Input ptr to operate on + slice_start_loc, + # input ptr strides + input_d0_stride, + input_d1_stride, + input_d2_stride, + # lora ptr strides + ls_d0_ptr, + ls_d1_ptr, + ls_d2_ptr, + # out ptr strides + output_d0_stride, + output_d1_stride, + # constants + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SAME_STRIDE: tl.constexpr, + SLICE_NUM: tl.constexpr, + EVEN_K: tl.constexpr, + CAST_TYPE: tl.constexpr, + ADD_INPUTS: tl.constexpr, +): + """ + Given an array of integers that identifies the rows of A, ram, + a lora index that identifies which LoRA to use from lora_ptr, lora_index, + a slice_id that identifies the input/output slice, + compute the matrix product and store in the appropriate output location. + Given that this is an expand kernel, we don't perform any split-K reduction + as the K dimension is assumed to be small. + """ + + # ls_d*_ptr can be either an integer or a pointer + if SAME_STRIDE: + # integer + cur_lora_d0_stride = ls_d0_ptr + cur_lora_d1_stride = ls_d1_ptr + cur_lora_d2_stride = ls_d2_ptr + else: + # pointer + cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) + cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) + cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) + + # Identify the input_ptr and lora_ptr from slice_id. + if SLICE_NUM == 1: + cur_input_ptr = input_ptr + cur_lora_ptr = lora_ptr + else: + cur_input_ptr = input_ptr + slice_id * input_d0_stride + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(out_ptr.dtype.element_ty)) + + # Identify the column indices of B to process. + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + # Identify A and B block pointers + offset_k = tl.arange(0, BLOCK_K) + a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride + + offset_k[None, :] * input_d2_stride) + b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + + offset_k[:, None] * cur_lora_d2_stride + + rbn[None, :] * cur_lora_d1_stride) + + # Compute the block matrix product. + SPLIT_K = 1 + accumulator = mm_k(a_ptr, b_ptr, input_d2_stride, cur_lora_d2_stride, + offset_k, K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, + CAST_TYPE, cur_lora_ptr.dtype.element_ty) + + tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty) + if SLICE_NUM == 1: + cur_slice_start = slice_start_loc + else: + cur_slice_start = tl.load(slice_start_loc + slice_id) + + # Identify the C output pointers to store the results of the accumulator. + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start + offset_cm = tl.arange(0, BLOCK_M) + c_ptr = (out_ptr + ram[:, None] * output_d0_stride + + offset_cn[None, :] * output_d1_stride) + c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] + < (cur_slice_start + N)) + + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@triton.jit +def do_shrink_kernel( + pid_n, + pid_sk, + slice_id, + lora_index, + input_ptr, + lora_ptr, + out_ptr, + N, + K, + M_LEN, + ram, + # input strides + input_d0_stride, + input_d1_stride, + # lora strides + lora_d0_stride, + lora_d1_stride, + lora_d2_stride, + # output strides + output_d0_stride, + output_d1_stride, + output_d2_stride, + scaling, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, + SLICE_NUM: tl.constexpr, +): + """ + Given an array of integers that identifies the rows of A, ram, + a lora index that identifies which LoRA to use from lora_ptr, lora_index, + a slice_id that identifies the input/output slice, compute the + matrix product and store in the appropriate output location. + """ + + # Identify the lora_ptr from slice_id. + if SLICE_NUM == 1: + # current lora ptr + cur_lora_ptr = lora_ptr + else: + # current lora ptr + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(input_ptr.dtype.element_ty)) + + # Identify the column indices of B to process. + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + # Identify A and B block pointers + offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K) + a_ptr = (input_ptr + ram[:, None] * input_d0_stride + + offset_k[None, :] * input_d1_stride) + b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index + + rbn[None, :] * lora_d1_stride + + offset_k[:, None] * lora_d2_stride) + + # Compute partial/complete block matrix product. + accumulator = mm_k(a_ptr, b_ptr, input_d1_stride, lora_d2_stride, offset_k, + K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, False, + cur_lora_ptr.dtype.element_ty) + + # Identify the C output pointers to store the results of the accumulator. + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_cm = tl.arange(0, BLOCK_M) + cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr + + slice_id * output_d0_stride) + c_ptr = cur_out_ptr + ram[:, None] * output_d1_stride + offset_cn[ + None, :] * output_d2_stride + c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N) + + accumulator *= scaling + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(c_ptr, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptr, accumulator, mask=c_mask) diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py new file mode 100644 index 0000000..9e1f90e --- /dev/null +++ b/vllm/lora/ops/triton_ops/lora_expand_op.py @@ -0,0 +1,290 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel +from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr +from vllm.utils import direct_register_custom_op + + +@triton.jit +def _lora_expand_kernel( + input_ptr, + lora_ptr, + out_ptr, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + slice_start_loc, + input_d0_stride, + input_d1_stride, + input_d2_stride, # 1 + ls_d0_ptr, + ls_d1_ptr, + ls_d2_ptr, # 1 + output_d0_stride, + output_d1_stride, # 1 + output_hs_ptr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, + SLICE_NUM: tl.constexpr, + SAME_STRIDE: tl.constexpr): + + cta_n_num = tl.cdiv(N, BLOCK_N) + cta_m_num = tl.cdiv(M, BLOCK_M) + + pid_mn = tl.program_id(axis=0) + pid_m = pid_mn % cta_m_num + pid_n = (pid_mn // cta_m_num) % cta_n_num + + slice_id = tl.program_id(axis=1) + lora_idx = tl.program_id(axis=2) + + lora_id = tl.load(lora_ids + lora_idx) + if lora_id == -1: + # Early exit for the no-lora case. + return + + lora_m_size = tl.load(num_tokens_per_lora + lora_idx) + + cta_m_offset = pid_m * BLOCK_M + if cta_m_offset >= lora_m_size: + # Early exit CTA. + return + + # When the output dimensions of each slice are the same,cur_n=N, otherwise + # cur_n=tl.load(output_hs_ptr + slice_id), this situation exists in GQA's + # qkv linear. + curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id) + if pid_n * BLOCK_N >= curr_N: + # Early exit CTA. + return + + # num rows this CTA should process. + cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset) + + # Identify all rows that this CTA should process. + lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx) + cta_lora_seq_indices = (token_indices_sorted_by_lora_ids + + lora_m_indices_start + cta_m_offset) + + # Load all relevant row indices. + offset_m = tl.arange(0, BLOCK_M) % cta_m_len + ram = tl.load(cta_lora_seq_indices + offset_m) + + do_expand_kernel( + pid_n, + lora_id, + slice_id, + input_ptr, + lora_ptr, + out_ptr, + curr_N, + K, + cta_m_len, + ram, # array identifying the rows of Input ptr to operate on + slice_start_loc, + # input ptr strides + input_d0_stride, + input_d1_stride, + input_d2_stride, + # lora ptr strides + ls_d0_ptr, + ls_d1_ptr, + ls_d2_ptr, + # out ptr strides + output_d0_stride, + output_d1_stride, + # constants + BLOCK_M, + BLOCK_N, + BLOCK_K, + SAME_STRIDE, + SLICE_NUM, + EVEN_K, + CAST_TYPE, + ADD_INPUTS) + + +@torch.inference_mode() +def _lora_expand( + inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] + lora_b_weights: list[ + torch.Tensor], # shape [num_lora, hidden_size, lora_rank] + output_tensor: torch. + Tensor, # shape [num_tokens, hidden_size * num_slices] + token_lora_mapping: torch.Tensor, # shape [num_tokens] + token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens] + num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1] + lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] + lora_ids: torch.Tensor, # shape [max-loras + 1] + no_lora_flag_cpu: torch.Tensor, # shape [1] + offset_start: int = 0, + add_inputs: bool = False, +) -> None: + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (list[torch.Tensor]): lora'b weight + output_tensor (torch.Tensor): output tensor + token_lora_mapping (torch.Tensor): A tensor mapping each input token + to the lora-id related to that token. A value of -1 indicates that + LoRA doesn't apply to that token. + token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from + the A matrix grouped by LoRA IDs. + num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number + of tokens that are to be processed by LoRA ID lora_ids[i] + lora_token_start_loc (torch.Tensor): A cumulative sum of + num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that + lora_token_start_loc[i], along with num_tokens_per_lora[i] + identifies the region in token_indices_sorted_by_lora_ids that + LoRA lora_ids[i] should process. + lora_ids (torch.Tensor): LoRA ids to process. + no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates + if there are any requests that require LoRA. + offset_start (int, optional): Offset start for output_tensor. + Defaults to 0. + add_inputs (bool, optional): Whether to add the input tensor to the + output tensor. Defaults to False. + """ + + assert no_lora_flag_cpu.numel() == 1 + if no_lora_flag_cpu.item(): + # None of the inputs require LoRA. + return + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + for weight in lora_b_weights: + assert weight.dtype in [torch.float16, torch.bfloat16] + + assert inputs.size(0) == len(lora_b_weights) + assert output_tensor.is_contiguous() + + # metadata sanity check. + M = inputs.size(1) + assert token_lora_mapping.size(0) == M + assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size( + 0) + assert lora_ids.size(0) == num_tokens_per_lora.size(0) + assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 + + (slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor, + lora_strides_d1_tensor, lora_strides_d2_tensor, hidden_sizes_tensor, + same_stride, MAX_N) = _get_lora_b_ptr(lora_b_weights, offset_start, + inputs.device) + + K = lora_b_weights[0].shape[-1] # K= rank + ADD_INPUTS = add_inputs + MAX_LORAS = lora_ids.size(0) + CAST_TYPE = False + NUM_SLICES = len(lora_b_weights) + + # Triton kernel configs. + BLOCK_M = 64 + BLOCK_N = 128 + BLOCK_K = 16 + NUM_WARPS = 4 + NUM_CTAS = 1 + NUM_STAGES = 2 + + EVEN_K = K % BLOCK_K == 0 # type: ignore + + if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + + # TODO (varun): This grid formulation maximizes parallelization at the + # cost of wasteful thread block launch when only a few input tokens require + # LoRA. This might not be the best in all cases. + grid = ( + triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N), + NUM_SLICES, + # Each LoRA receives its own set of thread blocks for output + # computation. If some LoRA doesn't have any tokens to process, its + # thread blocks simply exit. + MAX_LORAS, + ) + + _lora_expand_kernel[grid]( + inputs, + lora_ptr_tensor, + output_tensor, + M, + MAX_N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + slice_start_tensor, + inputs.stride(0), + inputs.stride(1), + inputs.stride(2), + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, + output_tensor.stride(0), + output_tensor.stride(1), + hidden_sizes_tensor, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + NUM_SLICES, + same_stride, + num_warps=NUM_WARPS, + num_ctas=NUM_CTAS, + num_stages=NUM_STAGES, + ) + + return + + +def _lora_expand_fake( + inputs: torch.Tensor, + lora_b_weights: list[torch.Tensor], + output_tensor: torch.Tensor, + token_lora_mapping: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, + num_tokens_per_lora: torch.Tensor, + lora_token_start_loc: torch.Tensor, + lora_ids: torch.Tensor, + no_lora_flag_cpu: torch.Tensor, + offset_start: int = 0, + add_inputs: bool = False, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="lora_expand", + op_func=_lora_expand, + mutates_args=["output_tensor"], + fake_impl=_lora_expand_fake, + ) + lora_expand = torch.ops.vllm.lora_expand + +except AttributeError: + lora_expand = _lora_expand diff --git a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py new file mode 100644 index 0000000..39e647b --- /dev/null +++ b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py @@ -0,0 +1,148 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +LoRA kernels metadata preparation utilities. +""" + +from dataclasses import dataclass +from typing import Union + +import torch + + +@dataclass +class LoRAKernelMeta: + token_lora_mapping: torch.Tensor + token_indices_sorted_by_lora_ids: torch.Tensor + active_lora_ids: torch.Tensor + num_tokens_per_lora: torch.Tensor + lora_token_start_loc: torch.Tensor + + # The V1 architecture uses the traced torch.compile graphs to execute + # a forward pass. Things to note about this process, + # 1. The tracing infers all python scalar datatype objects into a constant + # value. + # 2. The tracing cannot handle dynamic control flow. (dynamic control flow + # is an experimental feature in pytorch) + # 3. The internals of torch.ops functions are not traced. + # We disguise the "no_lora" flag as a cpu tensor and leverage point number 3 + # to early exit from inside the lora_expand / lora_shrink torch operation. + no_lora_flag_cpu: torch.Tensor + + @staticmethod + def make(max_loras: int, max_num_tokens: int, + device: Union[torch.device, str]) -> "LoRAKernelMeta": + + token_lora_mapping = torch.empty(max_num_tokens, + dtype=torch.int32, + device=device) + + token_indices_sorted_by_lora_ids = torch.empty(max_num_tokens, + dtype=torch.int32, + device=device) + + # +1 because "no-lora" is also a possibility + # example: let max_loras be 3, active_lora_ids of [-1, 0, 2, 1] + # is a possibility. + active_lora_ids = torch.empty(max_loras + 1, + dtype=torch.int32, + device=device) + + # using running example, [3, 10, 5, 2] is a possibility. + num_tokens_per_lora = torch.zeros(max_loras + 1, + dtype=torch.int32, + device=device) + + # +2 for this because, the first index is always 0. + # using running example, lora_token_start_loc + # is [0, 3, 13, 18, 20]. + lora_token_start_loc = torch.zeros(max_loras + 2, + dtype=torch.int32, + device=device) + + no_lora_flag_cpu = torch.tensor([False], + dtype=torch.bool, + device='cpu') + + return LoRAKernelMeta( + token_lora_mapping=token_lora_mapping, + token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids, + active_lora_ids=active_lora_ids, + num_tokens_per_lora=num_tokens_per_lora, + lora_token_start_loc=lora_token_start_loc, + no_lora_flag_cpu=no_lora_flag_cpu) + + def _reset(self): + self.active_lora_ids.fill_(-1) + self.num_tokens_per_lora.fill_(0) + self.lora_token_start_loc.fill_(0) + self.no_lora_flag_cpu.fill_(False) + + def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: + """ + Prepare kernel metadata tensors for the current forward pass. + + Args: + token_lora_tensor (torch.Tensor): Tensor containing lora indices + for each input token. + """ + + self._reset() + + # Check and record no-lora case. + no_lora = torch.all(token_lora_mapping == -1) + self.no_lora_flag_cpu[0] = no_lora + + if no_lora: + # Early exit. LoRA kernels will not be run. + return + + num_tokens = token_lora_mapping.size(0) + + # copy token lora mapping + self.token_lora_mapping[:num_tokens].copy_(token_lora_mapping, + non_blocking=True) + + # token_indices_sorted_by_lora_ids + _, token_indices_sorted_by_lora_ids = torch.sort(token_lora_mapping, + stable=True) + # start gpu transfer + self.token_indices_sorted_by_lora_ids[:num_tokens].copy_( + token_indices_sorted_by_lora_ids, non_blocking=True) + + # active_lora_ids, num_tokens_per_lora + lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping, + sorted=True, + return_counts=True) + self.active_lora_ids[:lora_ids.size(0)].copy_(lora_ids, + non_blocking=True) + self.num_tokens_per_lora[:num_tokens_per_lora.size(0)].copy_( + num_tokens_per_lora, non_blocking=True) + + # lora_token_start_loc + lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0) + self.lora_token_start_loc[1:1 + lora_token_start_loc.size(0)].copy_( + lora_token_start_loc, non_blocking=True) + + def meta_args( + self, token_nums: int + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor]: + """ + This function returns the kernel metadata required for the current + forward pass execution of the kernel. The function returns all the + metadata required by the kernel, in order, as a tuple, so it can be + unpacked directly during the lora_shrink/lora_expand function call. + + Args: + token_nums (int): Number of input tokens in the current forward + pass. + """ + return ( + self.token_lora_mapping[:token_nums], + self.token_indices_sorted_by_lora_ids[:token_nums], + self.num_tokens_per_lora, + self.lora_token_start_loc, + self.active_lora_ids, + self.no_lora_flag_cpu, + ) diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py new file mode 100644 index 0000000..3f9edfc --- /dev/null +++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel +from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr +from vllm.utils import direct_register_custom_op + + +@triton.jit +def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K, + token_indices_sorted_by_lora_ids, num_tokens_per_lora, + lora_token_start_loc, lora_ids, scaling, + input_d0_stride, input_d1_stride, lora_d0_stride, + lora_d1_stride, lora_d2_stride, output_d0_stride, + output_d1_stride, output_d2_stride, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, SLICE_NUM: tl.constexpr): + + cta_n_num = tl.cdiv(N, BLOCK_N) + cta_m_num = tl.cdiv(M, BLOCK_M) + + pid_sk_m_n = tl.program_id(axis=0) + pid_sk = pid_sk_m_n % SPLIT_K + pid_m = (pid_sk_m_n // SPLIT_K) % cta_m_num + pid_n = pid_sk_m_n // (SPLIT_K * cta_m_num) % cta_n_num + + slice_id = tl.program_id(axis=1) + lora_idx = tl.program_id(axis=2) + + lora_id = tl.load(lora_ids + lora_idx) + if lora_id == -1: + # Early exit for the no-lora case. + return + + lora_m_size = tl.load(num_tokens_per_lora + lora_idx) + + cta_m_offset = pid_m * BLOCK_M + if cta_m_offset >= lora_m_size: + # Early exit CTA. + return + + # num rows this CTA should process. + cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset) + + # Identify all rows that this CTA should process. + lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx) + cta_lora_seq_indices = (token_indices_sorted_by_lora_ids + + lora_m_indices_start + cta_m_offset) + + # Load all relevant row indices. + offset_m = tl.arange(0, BLOCK_M) % cta_m_len + ram = tl.load(cta_lora_seq_indices + offset_m) + + do_shrink_kernel( + pid_n, + pid_sk, + slice_id, + lora_id, + input_ptr, + lora_ptr, + out_ptr, + N, + K, + cta_m_len, + ram, # array identifying the rows of Input ptr to operate on + # input strides + input_d0_stride, + input_d1_stride, + # lora strides + lora_d0_stride, + lora_d1_stride, + lora_d2_stride, + # output strides + output_d0_stride, + output_d1_stride, + output_d2_stride, + scaling, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + SLICE_NUM) + + +@torch.inference_mode() +def _lora_shrink( + inputs: torch.Tensor, # shape [num_tokens, hidden_size] + lora_a_weights: list[ + torch.Tensor], # shape [num_loras, lora_rank, hidden_size] + output_tensor: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] + token_lora_mapping: torch.Tensor, # shape [num_tokens] + token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens] + num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1] + lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] + lora_ids: torch.Tensor, # shape [max-loras + 1] + no_lora_flag_cpu: torch.Tensor, # shape [1] + scaling: float, +) -> None: + """ + Args: + inputs (torch.Tensor): Input tensor + lora_a_weights (list[torch.Tensor]): LoRA weights + output_tensor (torch.Tensor): output tensor + token_lora_mapping (torch.Tensor): A tensor mapping each input token + to the lora-id related to that token. A value of -1 indicates that + LoRA doesn't apply to that token. + token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from + the A matrix grouped by LoRA IDs. + num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number + of tokens that are to be processed by LoRA ID lora_ids[i] + lora_token_start_loc (torch.Tensor): A cumulative sum of + num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that + lora_token_start_loc[i], along with num_tokens_per_lora[i] + identifies the region in token_indices_sorted_by_lora_ids that + LoRA lora_ids[i] should process. + lora_ids (torch.Tensor): LoRA ids to process. + no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates + if there are any requests that require LoRA. + scaling (float): Scaling factor. + """ + + assert no_lora_flag_cpu.numel() == 1 + if no_lora_flag_cpu.item(): + # None of the inputs require LoRA. + return + + assert inputs.dtype == lora_a_weights[0].dtype + assert inputs.dtype in [torch.float16, torch.bfloat16] + for weight in lora_a_weights: + assert weight.dtype in [torch.float16, torch.bfloat16] + + assert inputs.size(1) == lora_a_weights[0].size(-1) + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + # metadata sanity check + M = inputs.size(0) + assert token_lora_mapping.size(0) == M + assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size( + 0) + assert lora_ids.size(0) == num_tokens_per_lora.size(0) + assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 + + (lora_ptr_tensor, lora_strides_d0, lora_strides_d1, + lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device) + N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank + NUM_SLICES = len(lora_a_weights) + MAX_LORAS = lora_ids.size(0) + + # Triton kernel configs + BLOCK_M = 32 + BLOCK_N = 16 + BLOCK_K = 256 if M < 128 else 32 + SPLIT_K = 64 if M < 128 else 8 + NUM_WARPS = 4 + NUM_CTAS = 1 + NUM_STAGES = 2 + + EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore + + # TODO (varun): This grid formulation maximizes parallelization at the + # cost of wasteful thread block launch when only few of the input tokens + # require LoRA. This might not be the best in all cases. + grid = ( + SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), + NUM_SLICES, + # Each LoRA receives its own set of thread blocks for output + # computation. If some LoRA doesn't have any tokens to process, its + # thread blocks exit early. + MAX_LORAS, + ) + + _lora_shrink_kernel[grid]( + inputs, + lora_ptr_tensor, + output_tensor, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + scaling, + inputs.stride(0), + inputs.stride(1), + lora_strides_d0, + lora_strides_d1, + lora_strides_d2, + output_tensor.stride(0), + output_tensor.stride(1), + output_tensor.stride(2), + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + NUM_SLICES, + num_warps=NUM_WARPS, + num_ctas=NUM_CTAS, + num_stages=NUM_STAGES, + ) + + return + + +def _lora_shrink_fake( + inputs: torch.Tensor, + lora_a_weights: list[torch.Tensor], + output_tensor: torch.Tensor, + token_lora_mapping: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, + num_tokens_per_lora: torch.Tensor, + lora_token_start_loc: torch.Tensor, + lora_ids: torch.Tensor, + no_lora_flag_cpu: torch.Tensor, + scaling: float, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="lora_shrink", + op_func=_lora_shrink, + mutates_args=["output_tensor"], + fake_impl=_lora_shrink_fake, + ) + lora_shrink = torch.ops.vllm.lora_shrink + +except AttributeError: + lora_shrink = _lora_shrink diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py new file mode 100644 index 0000000..5857f7f --- /dev/null +++ b/vllm/lora/ops/triton_ops/utils.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {} +_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {} + + +def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device): + """ + `_LORA_A_PTR_DICT` collects the required information during `profile_run`, + After this, it remains constant and subsequent usage is through LUT. + Refer to: + https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py + """ + key = tuple(lora_weight.data_ptr() for lora_weight in lora_a_weights) + + if values := _LORA_A_PTR_DICT.get(key): + return values + + lora_strides_d0 = [] + lora_strides_d1 = [] + lora_strides_d2 = [] + tensor_ptrs = [] + for lora_a_weight in lora_a_weights: + if lora_a_weight.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_a_weight.size(1) == 1 + lora_a_weight = lora_a_weight.squeeze(dim=1) + else: + assert lora_a_weight.ndim == 3 # shape:(lora_num,size,rank) + assert lora_a_weight.is_contiguous() + tensor_ptrs.append(lora_a_weight.data_ptr()) + lora_strides_d0.append(lora_a_weight.stride(0)) + lora_strides_d1.append(lora_a_weight.stride(1)) + lora_strides_d2.append(lora_a_weight.stride(2)) + if len(lora_a_weights) > 1: + lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) + else: + lora_ptr_tensor = lora_a_weights[0] + + if (len(set(lora_strides_d0)) > 1 or len(set(lora_strides_d1)) > 1 + or len(set(lora_strides_d2)) > 1): + raise ValueError("All LoRA weights must have the same stride.") + + _LORA_A_PTR_DICT[key] = ( + lora_ptr_tensor, + lora_strides_d0[0], + lora_strides_d1[0], + lora_strides_d2[0], + ) + return _LORA_A_PTR_DICT.get(key) + + +def _get_lora_b_ptr(lora_weights: list[torch.Tensor], offset_start: int, + device: torch.device): + """ + `_LORA_B_PTR_DICT` collects the required information during `profile_run`, + After this, it remains constant and subsequent usage is through LUT. + Refer to: + https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py + + """ + + key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights) + if values := _LORA_B_PTR_DICT.get(key): + return values + slice_offset_lst = [] + tensor_ptrs = [] + lora_strides_d0 = [] + lora_strides_d1 = [] + lora_strides_d2 = [] + hidden_sizes = [] + slice_offset = offset_start + for lora_b_weight in lora_weights: + if lora_b_weight.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weight.size(1) == 1 + lora_b_weight = lora_b_weight.squeeze(dim=1) + else: + assert lora_b_weight.ndim == 3 # shape:(lora_num,size,rank) + assert lora_b_weight.is_contiguous() + tensor_ptrs.append(lora_b_weight.data_ptr()) + lora_strides_d0.append(lora_b_weight.stride(0)) + lora_strides_d1.append(lora_b_weight.stride(1)) + lora_strides_d2.append(lora_b_weight.stride(2)) + slice_offset_lst.append(slice_offset) + slice_offset += lora_b_weight.size(1) + hidden_sizes.append(lora_b_weight.size(1)) + + if len(lora_weights) > 1: + # note these are device tensors + lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) + slice_start_tensor = torch.tensor(slice_offset_lst, device=device) + else: + slice_start_tensor = slice_offset_lst[0] + lora_ptr_tensor = lora_b_weight[0] + + # If each lora has the same stride, there's no need to use a + # tensor for storage. + if (len(set(lora_strides_d0)) == 1 and len(set(lora_strides_d1)) == 1 and + len(set(lora_strides_d2)) == 1) and len(set(hidden_sizes)) == 1: + lora_strides_d0_tensor = lora_strides_d0[0] + lora_strides_d1_tensor = lora_strides_d1[0] + lora_strides_d2_tensor = lora_strides_d2[0] + hidden_sizes_tensor = hidden_sizes[0] + same_stride = True + + else: + lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device) + lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device) + lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device) + hidden_sizes_tensor = torch.tensor(hidden_sizes, device=device) + same_stride = False + # MAX_N is the maximum hidden size among all the lora_b weights + MAX_N = max(hidden_sizes) + _LORA_B_PTR_DICT[key] = (slice_start_tensor, lora_ptr_tensor, + lora_strides_d0_tensor, lora_strides_d1_tensor, + lora_strides_d2_tensor, hidden_sizes_tensor, + same_stride, MAX_N) + return _LORA_B_PTR_DICT.get(key) diff --git a/vllm/lora/ops/xla_ops/__init__.py b/vllm/lora/ops/xla_ops/__init__.py new file mode 100644 index 0000000..7e7c3c8 --- /dev/null +++ b/vllm/lora/ops/xla_ops/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink) + +__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"] diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py new file mode 100644 index 0000000..9118f33 --- /dev/null +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import jax +import jax.numpy as jnp +import torch +import torch.nn.functional as F +import torch_xla.core.xla_builder as xb +from torch.library import impl +from torch_xla.experimental.custom_kernel import XLA_LIB, jax_import_guard + + +@jax.jit +def bgmv_jax(inputs, loras, idxs): + return jnp.einsum( + "td,tX,Xld->tl", + inputs, + jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype), + loras, + ) + + +XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor") + + +@impl(XLA_LIB, "bgmv", "XLA") +def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): + if len(loras.shape) == 4: + loras = loras.squeeze(axis=1) + + jax_import_guard() + return xb.call_jax(bgmv_jax, (inputs, loras, idxs)) + + +@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd") +def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, + idxs: torch.IntTensor): + T, _ = inputs.shape + if len(loras.shape) == 4: + loras = loras.squeeze(axis=1) + _, L, _ = loras.shape + + return torch.empty((T, L), device=inputs.device) + + +def bgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, +): + """ + Args: + inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. + + lora_b_weights (torch.Tensor): LoRA weights of shape + [num_loras, lora_rank, hidden_size]. + + output_tensor (torch.Tensor): output tensor of shape + [num_tokens, hidden_size * num_slices]. + + lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] + indicating which LoRA matrix to use for each token. + add_inputs (bool): Whether or not to add the input tensor to the output + tensor. + """ + + outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) + + limit = output_tensor.shape[0] + if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: + limit = 1 + + if output_tensor.shape[1] > outputs.shape[1]: + outputs = F.pad(outputs, + (0, output_tensor.shape[1] - outputs.shape[1], 0, 0)) + + if add_inputs: + return output_tensor + outputs[:limit, :output_tensor.shape[1]] + else: + return outputs[:limit, :output_tensor.shape[1]] + + +def bgmv_shrink( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0, +): + """ + Args: + inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. + lora_b_weights (torch.Tensor): LoRA weights of shape + [num_loras, lora_rank, hidden_size]. + output_tensor (torch.Tensor): (Unused) output tensor (placeholder). + lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] + indicating which LoRA matrix to use for each token. + scaling (float, optional): Scalar multiplier applied to the output. + """ + + return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, + lora_indices_tensor) + + +def bgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True, +): + """ + Args: + inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. + + lora_b_weights (torch.Tensor): LoRA weights of shape + [num_loras, lora_rank, hidden_size]. + + output_tensor (torch.Tensor): output tensor of shape + [num_tokens, hidden_size * num_slices]. + + lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] + indicating which LoRA matrix to use for each token. + add_inputs (bool): Whether or not to add the input tensor to the output + tensor. + """ + outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) + + outputs = F.pad( + outputs, + ( + slice_offset, + output_tensor.shape[1] - (slice_offset + slice_size), + 0, + 0, + ), + ) + + if add_inputs: + return output_tensor + outputs + else: + return outputs diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py new file mode 100644 index 0000000..a20d73f --- /dev/null +++ b/vllm/lora/peft_helper.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py + +import json +import math +import os +from dataclasses import MISSING, dataclass, field, fields +from typing import Literal, Optional, Union + +from vllm.config import LoRAConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader.tensorizer import TensorizerConfig + +logger = init_logger(__name__) + + +@dataclass +class PEFTHelper: + """ + A helper class for PEFT configurations, specifically designed for LoRA. + This class handles configuration validation, compatibility checks for + various LoRA implementations. + """ + + # Required fields + r: int + lora_alpha: int + target_modules: Union[list[str], str] + + bias: Literal["none", "all", "lora_only"] = field(default="none") + modules_to_save: Optional[list[str]] = field(default=None) + # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732) + use_rslora: bool = field(default=False) + # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353) + use_dora: bool = field(default=False) + # long context lora field + context_length: int = field(default=0) + # Extra vllm field, start with 'vllm_' to avoid conflict + vllm_lora_scaling_factor: float = field(default=1.0) + vllm_max_position_embeddings: Optional[int] = field(default=False) + vllm_long_context_scaling_factor: Optional[float] = field(default=None) + + def _validate_features(self) -> list[str]: + """ + Check if there are any unsupported LoRA features. + """ + error_msg = [] + if self.modules_to_save: + error_msg.append("vLLM only supports modules_to_save being None.") + if self.use_dora: + error_msg.append("vLLM does not yet support DoRA.") + return error_msg + + def __post_init__(self): + if self.use_rslora: + logger.info_once("Loading LoRA weights trained with rsLoRA.") + self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r) + else: + self.vllm_lora_scaling_factor = self.lora_alpha / self.r + if self.context_length: + if self.vllm_max_position_embeddings is None: + self.vllm_max_position_embeddings = self.context_length + self.vllm_long_context_scaling_factor = float( + math.ceil(self.context_length / + self.vllm_max_position_embeddings)) + + @classmethod + def from_dict(cls, config_dict: dict) -> "PEFTHelper": + # Get all field information from the class + class_fields = {f.name: f for f in fields(cls)} + # Check for required fields + required_fields = { + name + for name, f in class_fields.items() + if f.default is MISSING and f.default_factory is MISSING + } + + # Identify any missing required fields + missing_fields = required_fields - set(config_dict.keys()) + if missing_fields: + raise ValueError( + f"Missing required configuration fields: {missing_fields}") + + # Filter out fields that aren't defined in the class + filtered_dict = { + k: v + for k, v in config_dict.items() if k in class_fields + } + return cls(**filtered_dict) + + @classmethod + def from_local_dir( + cls, + lora_path: str, + max_position_embeddings: Optional[int], + tensorizer_config_dict: Optional[dict] = None) -> "PEFTHelper": + lora_config_path = os.path.join(lora_path, "adapter_config.json") + + if tensorizer_config_dict: + tensorizer_config = TensorizerConfig(**tensorizer_config_dict) + tensorizer_args = tensorizer_config._construct_tensorizer_args() + from tensorizer.stream_io import open_stream + lora_config_path = os.path.join(tensorizer_config.lora_dir, + "adapter_config.json") + with open_stream(lora_config_path, + mode="rb", + **tensorizer_args.stream_params) as f: + config = json.load(f) + + logger.info("Successfully deserialized LoRA config from %s", + tensorizer_config.lora_dir) + + else: + with open(lora_config_path) as f: + config = json.load(f) + + config["vllm_max_position_embeddings"] = max_position_embeddings + return cls.from_dict(config) + + def validate_legal(self, lora_config: LoRAConfig) -> None: + """ + Validates the LoRA configuration settings against application + constraints and requirements. + """ + error_msg = self._validate_features() + if self.r > lora_config.max_lora_rank: + error_msg.append( + f"LoRA rank {self.r} is greater than max_lora_rank" + f" {lora_config.max_lora_rank}.") + if self.bias != "none" and not lora_config.bias_enabled: + error_msg.append( + "Adapter bias cannot be used without bias_enabled.") + if error_msg: + raise ValueError(f"{' '.join(error_msg)}") diff --git a/vllm/lora/punica_wrapper/__init__.py b/vllm/lora/punica_wrapper/__init__.py new file mode 100644 index 0000000..e664ffa --- /dev/null +++ b/vllm/lora/punica_wrapper/__init__.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase +from vllm.lora.punica_wrapper.punica_selector import get_punica_wrapper + +__all__ = [ + "PunicaWrapperBase", + "get_punica_wrapper", +] diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py new file mode 100644 index 0000000..5b4902d --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -0,0 +1,485 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Optional, Union + +import torch + +from .utils import compute_meta, convert_mapping + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext + + +class PunicaWrapperABC(ABC): + """ + PunicaWrapper ABC. + """ + + @abstractmethod + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: list[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs, + ) -> None: + """ + Update the lora-related metadata + """ + raise NotImplementedError + + @abstractmethod + def add_shrink( + self, + y: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + scale: float, + **kwargs, + ) -> Optional[torch.Tensor]: + """ + Performs GEMM for multiple slices of lora_a. + """ + + raise NotImplementedError + + @abstractmethod + def add_expand( + self, + y: torch.Tensor, + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs, + ) -> Optional[torch.Tensor]: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + """ + raise NotImplementedError + + @abstractmethod + def add_lora_embedding( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs, + ) -> Optional[torch.Tensor]: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA, + and this layer only requires the expand operation. + """ + raise NotImplementedError + + @abstractmethod + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + scale: float, + output_slices: tuple[int, ...], + *, + buffer: Optional[tuple[torch.Tensor, ...]] = None, + **kwargs) -> Optional[torch.Tensor]: + """ + Applicable to linear-related lora. + """ + + raise NotImplementedError + + @abstractmethod + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> Optional[torch.Tensor]: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + """ + raise NotImplementedError + + +class PunicaWrapperBase(PunicaWrapperABC): + """ + PunicaWrapperBase is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the punica. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + self._token_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices_padded = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._embeddings_indices = torch.empty(2, + max_num_batched_tokens, + dtype=torch.long, + device=device) + self._long_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + + # 5 is the number of indices tensors. + # base_indices, sampler_indices, sampler_indices_padded, + # embeddings_indices,long_lora_indices + self.indices_len: list[Optional[int]] = [None] * 5 + # these attributes are the information required for sgmv kernel + self._seq_start_locs = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._seq_lengths = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._lora_indices_per_batch = torch.empty(max_batches, + dtype=torch.long, + device=device) + self.device: torch.device = device + self.max_length: int = 0 + self.token_nums: int = 0 + self.batch_size: int = -1 + self.is_prefill = False + self.no_lora = False + + def _update_base_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: list[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_offsets_tensor, + indices_len, + ) = convert_mapping( + mapping, + lora_index_to_id, + max_loras, + vocab_size, + extra_vocab_size, + self.device, + long_lora_context, + ) + self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) + self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self._embeddings_indices[:embeddings_indices. + shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + if long_lora_offsets_tensor is not None: + self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( + long_lora_offsets_tensor) + else: + self._long_lora_indices.zero_() + self.indices_len[:] = indices_len + + def _update_prefill_metadata(self, + token_lora_tensor: torch.Tensor) -> None: + + (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, token_nums, + no_lora) = compute_meta(token_lora_tensor) + + self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( + b_seq_start_tensor) + self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor) + self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_( + lora_indices_tensor) + self.batch_size = batch_size + self.max_length = max_length + self.token_nums = token_nums + self.no_lora = no_lora + + def _apply_bias( + self, + indices: torch.Tensor, + output: torch.Tensor, + output_slices: tuple[int, ...], + lora_bias_stacked: tuple[Optional[torch.Tensor], ...], + ): + """Applies bias to output + + Input shapes: + lora_bias_stacked: 3 element tuple of (num_loras, output_dim) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), + where n is number of slices + """ + org_output = output + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + + offset_left = 0 + for slice_idx, slice in enumerate(output_slices): + bias = lora_bias_stacked[slice_idx] + if bias is not None: + bias = bias.view(-1, bias.shape[-1]) + bias = bias[indices] + bias[indices == -1] = 0 + output[:, offset_left:offset_left + slice] += bias + offset_left += slice + + return output.view_as(org_output) + + @property + def prefill_metadata( + self + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: + """ + This property provides a convenient way to access the necessary + metadata for prefill-related kernel computations. + 1. seq_start_locs: Tensor of sequence start positions. + 2. seq_lengths: Tensor of sequence lengths. + 3. lora_indices_per_batch: Tensor of lora indices, and an index of + -1 means no lora should be applied. + 4. batch_size: Batch size after clustering identical lora indices. + 5. max_length: The maximum sequence length in the batch. + 6. token_nums: The token numbers in the batch. + """ + return (self._seq_start_locs[:self.batch_size], + self._seq_lengths[:self.batch_size], + self._lora_indices_per_batch[:self.batch_size], + self.batch_size, self.max_length, self.token_nums) + + @property + def token_lora_indices(self) -> torch.Tensor: + """ + This property provides the lora indices corresponding to each token + in the batch. An index of -1 means no lora should be applied. + """ + token_lora_len = self.indices_len[0] + return self._token_lora_indices[:token_lora_len] + + @property + def sampler_indices(self) -> torch.Tensor: + """ + This property is used to access the lora indices specifically for + LogitsProcessorWithLoRA. + """ + sampler_indices_len = self.indices_len[1] + return self._sampler_indices[:sampler_indices_len] + + @property + def sampler_indices_padded(self) -> torch.Tensor: + """ + This property provides access to padded sampler indices. + """ + indices_padded_len = self.indices_len[2] + return self._sampler_indices_padded[:indices_padded_len] + + @property + def embeddings_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for lora embeddings, + specifically for VocabParallelEmbeddingWithLoRA. + """ + embeddings_indices_len = self.indices_len[3] + return self._embeddings_indices[:, :embeddings_indices_len] + + @property + def long_lora_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for long context + lora, specifically for LinearScalingRotaryEmbeddingWithLoRA. + """ + long_lora_len = self.indices_len[4] + return self._long_lora_indices[:long_lora_len] + + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: list[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs): + + self._update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) + if mapping.is_prefill: + # Update metadata required for prefill-related operators. + self._update_prefill_metadata(self.token_lora_indices) + self.is_prefill = True + else: + self.is_prefill = False + + @abstractmethod + def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], + scale: float, **kwargs) -> Optional[torch.Tensor]: + """ + Performs GEMM for multiple slices of lora_a. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_expand(self, + y: torch.Tensor, + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> Optional[torch.Tensor]: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + offset = offset_start + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): + bias's weight + output_slices (tuple[int, ...]): Every slice's size + offset_start (int): The starting position of y, defaults to 0 + add_inputs (bool): Defaults to True. + + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> Optional[torch.Tensor]: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + and this layer only requires the expand operation. + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + scale: float, + output_slices: tuple[int, ...], + *, + buffer: Optional[tuple[torch.Tensor, ...]] = None, + **kwargs) -> Optional[torch.Tensor]: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (tuple[int, ...]): Every slice's size. + buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. + """ + # TODO: implement it based on torch ops + raise NotImplementedError + + @abstractmethod + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> Optional[torch.Tensor]: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + # TODO: implement it based on torch ops + raise NotImplementedError diff --git a/vllm/lora/punica_wrapper/punica_cpu.py b/vllm/lora/punica_wrapper/punica_cpu.py new file mode 100644 index 0000000..59049cc --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_cpu.py @@ -0,0 +1,349 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Callable, Optional, Union + +import torch + +from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) + +from .punica_base import PunicaWrapperBase + + +# The platforms that are compatible with the PyTorch-native implementation can +# inherit this class +class PunicaWrapperCPU(PunicaWrapperBase): + """ + PunicaWrapperCPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the pytorch punica ops. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + + def _shrink_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_shrink( + x, + w_t_all, + y, + *self.prefill_metadata, + scale, + ) + + def _shrink_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + + def _expand_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand( + x, + w_t_all, + y, + *self.prefill_metadata, + add_inputs, + ) + + def _expand_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) + + def _expand_slice_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand_slice( + x, + w_t_all, + y, + *self.prefill_metadata, + y_offset, + y_slice_size, + add_inputs, + ) + + def _expand_slice_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, + y_slice_size, add_inputs) + + def _apply_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool = True, + ): + """ + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + computation, which is suitable for the + GEMM of lora'b. + """ + + expand_slice_fun: Callable = (self._expand_slice_prefill + if self.is_prefill else + self._expand_slice_decode) + expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) + + def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, scale: float): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + shrink_fun: Callable = (self._shrink_prefill + if self.is_prefill else self._shrink_decode) + shrink_fun(y, x, w_t_all, scale) + y = y.view_as(y_org) + + def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], + scale: float, **kwargs): + """ + Performs GEMM for multiple slices of lora_a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + x = x.view(-1, x.shape[-1]) + # TODO fuse these kernels + for slice_idx in range(len(lora_a_stacked)): + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], + scale) + + def add_expand(self, + y: torch.Tensor, + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): + bias's weight + output_slices (tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = offset_start + if lora_bias_stacked is not None: + self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + self._apply_expand( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_inputs=add_inputs, + ) + offset_left += output_slices[slice_idx] + y = y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. + """ + + # Embedding layer only need expand op + expand_fun: Callable = (self._expand_prefill + if self.is_prefill else self._expand_decode) + expand_fun(y, x, lora_b_stacked, add_inputs) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + scale: float, + output_slices: tuple[int, ...], + *, + buffer: Optional[tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (tuple[int, ...]): Every slice's size. + buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. + """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].size(-1) + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = tuple( + torch.zeros( + (x.size(0), r), dtype=torch.float32, device=x.device) + for _ in range(len(output_slices))) + self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = lora_b_stacked.size(-1) + if buffer is None: + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + # LogitsProcessorWithLoRA always using bgmv. + bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) + bgmv_expand(buffer, + lora_b_stacked, + y, + self.sampler_indices, + add_inputs=True) + y = y.view_as(y_org) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py new file mode 100644 index 0000000..6b03830 --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -0,0 +1,290 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from typing import TYPE_CHECKING, Optional, Union, final + +import torch + +import vllm.envs as envs +from vllm.lora.layers import LoRAMapping +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand, + lora_shrink) + +from .punica_base import PunicaWrapperBase + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.models import LongContextLoRAContext + + +@final +class PunicaWrapperGPU(PunicaWrapperBase): + """ + PunicaWrapperGPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the punica triton kernel. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + + self.max_loras = kwargs['max_loras'] + + self.token_mapping_meta = LoRAKernelMeta.make(self.max_loras, + max_num_batched_tokens, + device=device) + + # When cudagraph capture size is greater than max_num_seqs (max_batches, + # here), V0 captures the graph as if max_num_seqs is set to + # the capture size. + # V1 doesn't have this problem and always respects max_num_seqs. + max_num_prompts = (max_batches + if envs.VLLM_USE_V1 else max_num_batched_tokens) + self.prompt_mapping_meta = LoRAKernelMeta.make(self.max_loras, + max_num_prompts, + device=device) + + def update_metadata( + self, + mapping: LoRAMapping, + lora_index_to_id: list[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs): + + self.is_prefill = mapping.is_prefill + self._update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) + + # Prepare cuda kernel metadata tensors + self.token_mapping_meta.prepare_tensors(self.token_lora_indices) + self.prompt_mapping_meta.prepare_tensors(self.sampler_indices) + + def add_shrink(self, y: torch.Tensor, x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, + ...], scale: float, **kwargs): + """ + Performs GEMM for multiple slices of lora_a. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (torch.Tensor): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + x = x.view(-1, x.shape[-1]) + lora_shrink( + x, + lora_a_stacked, + y, + *self.token_mapping_meta.meta_args(x.size(0)), + scale, + ) + + def add_expand(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensors + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): + bias's weight + output_slices (tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + if lora_bias_stacked is not None: + token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0, + y.size(0)) + self._apply_bias(token_lora_indices, y, output_slices, + lora_bias_stacked) + + assert x.ndim == 3 + assert x.size(0) == len(output_slices) + num_tokens = x.size(1) # first dimension is the num slices + + lora_expand( + x, + lora_b_stacked, + y, + *self.token_mapping_meta.meta_args(num_tokens), + offset_start=offset_start, + add_inputs=True, + ) + + y = y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. + """ + + lora_expand( + x.unsqueeze(dim=0), + (lora_b_stacked, ), + y, + *self.token_mapping_meta.meta_args(x.size(0)), + offset_start=0, + add_inputs=add_inputs, + ) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + scale: float, + output_slices: tuple[int, ...], + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (tuple[int, ...]): Every slice's size. + buffer (Optional[torch.Tensor]): Defaults to None. + """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0, + y.size(0)) + y = self._apply_bias(token_lora_indices, y, output_slices, + lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].size(-1) + # We set the buffer to be float32 by default, refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros( # type: ignore + (len(output_slices), x.size(0), r), + dtype=torch.float32, + device=x.device, + ) + self.add_shrink( + buffer, # type: ignore + x, + lora_a_stacked, + scale, + **kwargs) + self.add_expand( + y, + buffer, # type: ignore + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor): lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]): Default to None. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = lora_b_stacked.size(-1) + if buffer is None: + # We set the buffer to be float32 by default, refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + + lora_shrink(x, [lora_a_stacked], buffer.unsqueeze(dim=0), + *self.prompt_mapping_meta.meta_args(x.size(0)), scale) + + lora_expand(buffer.unsqueeze(dim=0), [lora_b_stacked], + y, + *self.prompt_mapping_meta.meta_args(buffer.size(0)), + add_inputs=True) + y = y.view_as(y_org) diff --git a/vllm/lora/punica_wrapper/punica_hpu.py b/vllm/lora/punica_wrapper/punica_hpu.py new file mode 100644 index 0000000..b20c978 --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_hpu.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import TYPE_CHECKING, Optional, Union, final + +import torch +from vllm_hpu_extension.ops import (dispatch_bgmv_embedding, + dispatch_bgmv_linear) + +from .punica_base import PunicaWrapperBase +from .utils import convert_mapping + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext + + +@final +class PunicaWrapperHPU(PunicaWrapperBase): + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + # Increasing max_num_batched_tokens by 3x to handle increase in + # tensor size due to padding. + PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens, + max_batches, device) + + def _update_base_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: list[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_offsets_tensor, + indices_len, + ) = convert_mapping(mapping, lora_index_to_id, max_loras, vocab_size, + extra_vocab_size, self.device, None) + # Updating each element in `long_lora_offsets` with `lora_offset` slows + # down perf in HPU due to a series of `strided_insert` ops during lazy + # graph accumulation. Hence HPU appends `lora_offset` to a list and + # converts it to a tensor only after it is ready. + if long_lora_context: + index_mapping_indices: list[int] = list( + mapping.index_mapping).copy() + long_lora_offsets: list[int] = [] + for i in range(len(index_mapping_indices)): + lora_offset: int = long_lora_context.offsets_by_lora_id.get( + index_mapping_indices[i], 0) + long_lora_offsets.append(lora_offset) + long_lora_offsets_tensor = torch.tensor(long_lora_offsets, + device=self.device, + dtype=torch.long) + indices_len[-1] = long_lora_offsets_tensor.shape[-1] + + self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) + self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self._embeddings_indices[:embeddings_indices. + shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + if long_lora_offsets_tensor is not None: + self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( + long_lora_offsets_tensor) + else: + self._long_lora_indices.zero_() + self.indices_len[:] = indices_len + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> None: + dispatch_bgmv_embedding(y, x, lora_b_stacked, 0) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + scale: float, + output_slices: tuple[int, ...], + *, + buffer: Optional[tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + y_org = y + x = x.view(-1, x.shape[-1]) + y = y.view(-1, y.shape[-1]) + offset_left = 0 + + for slice_idx in range(len(output_slices)): + dispatch_bgmv_linear( + y[:, offset_left:offset_left + output_slices[slice_idx]], x, + lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, scale) + offset_left += output_slices[slice_idx] + y = y.view_as(y_org) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + dispatch_bgmv_linear(y, x, lora_a_stacked, lora_b_stacked, 0, scale) + y = y.view_as(y_org) + + def add_shrink( + self, + y: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + scale: float, + **kwargs, + ) -> None: + raise NotImplementedError + + def add_expand( + self, + y: torch.Tensor, + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs, + ) -> None: + raise NotImplementedError diff --git a/vllm/lora/punica_wrapper/punica_selector.py b/vllm/lora/punica_wrapper/punica_selector.py new file mode 100644 index 0000000..c684ac7 --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_selector.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import resolve_obj_by_qualname + +from .punica_base import PunicaWrapperBase + +logger = init_logger(__name__) + + +def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: + punica_wrapper_qualname = current_platform.get_punica_wrapper() + punica_wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname) + punica_wrapper = punica_wrapper_cls(*args, **kwargs) + assert punica_wrapper is not None, \ + "the punica_wrapper_qualname(" + punica_wrapper_qualname + ") is wrong." + logger.info_once("Using %s.", punica_wrapper_qualname.rsplit(".", 1)[1]) + return punica_wrapper diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py new file mode 100644 index 0000000..6b48268 --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -0,0 +1,405 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from typing import TYPE_CHECKING, Optional, Union + +import torch +import torch.nn.functional as F +import torch_xla.core.xla_model as xm + +from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink +from vllm.lora.punica_wrapper.utils import convert_mapping + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext + +from .punica_base import PunicaWrapperBase + + +class PunicaWrapperTPU(PunicaWrapperBase): + """ + PunicaWrapperTPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the pytorch punica ops. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + + # PunicaWrapperBase defines some tensors with dtype=torch.int64, which + # isn't supported by the TPU. So convert those tensors to int32. + # Not all of them are used by the TPU so only convert the useful ones. + self._token_lora_indices = self._token_lora_indices.to( + dtype=torch.int32) + self._sampler_indices = self._sampler_indices.to(dtype=torch.int32) + self._sampler_indices_padded = self._sampler_indices_padded.to( + dtype=torch.int32) + + torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True) + torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True) + torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded, + True) + torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True) + torch.ops.xla.dynamo_set_buffer_donor_(self._long_lora_indices, True) + torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch, + True) + + torch._dynamo.mark_dynamic(self._token_lora_indices, 0) + torch._dynamo.mark_dynamic(self._embeddings_indices, 1) + torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0) + + def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor: + return torch.narrow(self._token_lora_indices, 0, 0, x.size(0)) + + @property + def embeddings_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for lora embeddings, + specifically for VocabParallelEmbeddingWithLoRA. + """ + return self._embeddings_indices[:] + + @property + def sampler_indices_padded(self) -> torch.Tensor: + """ + This property provides access to padded sampler indices. + """ + return self._sampler_indices_padded[:] + + def shrink( + self, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + return bgmv_shrink(x, w_t_all, self._get_token_lora_indices(x), scale) + + def expand(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, + add_inputs: bool): + return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x), + add_inputs) + + def expand_slice(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, + add_inputs: bool) -> torch.Tensor: + return bgmv_expand_slice(x, w_t_all, y, + self._get_token_lora_indices(x), y_offset, + y_slice_size, add_inputs) + + def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], + scale: float, **kwargs) -> Optional[torch.Tensor]: + """ + Performs GEMM for multiple slices of lora_a. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + torch.ops.xla.dynamo_set_buffer_donor_(y, True) + x = x.view(-1, x.shape[-1]) + + for slice_idx in range(len(lora_a_stacked)): + lora_s = lora_a_stacked[slice_idx] + y_s = self.shrink(x, lora_s, scale) + y[slice_idx, :, :] = y_s # type: ignore[index] + return y + + def add_expand(self, + y: torch.Tensor, + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> torch.Tensor: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): + bias's weight + output_slices (tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = 0 + + if lora_bias_stacked is not None: + y = self._apply_bias(self._get_token_lora_indices(y), y, + output_slices, lora_bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + y = self.expand_slice(y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_inputs=add_inputs) + offset_left += output_slices[slice_idx] + return y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> torch.Tensor: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. + """ + + # Embedding layer only needs the expand op + return self.expand(y, x, lora_b_stacked, add_inputs) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + scale: float, + output_slices: tuple[int, ...], + *, + buffer: Optional[tuple[torch.Tensor, ...]] = None, + **kwargs) -> torch.Tensor: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will not be changed in-place. + x (torch.Tensor): Input tensor (T, E) + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (tuple[int, ...]): Every slice's size. + buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. + """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self._get_token_lora_indices(y), y, + output_slices, lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].size(-1) + T = x.size(0) + buffer = torch.zeros( + (len(output_slices), T, r), + dtype=x.dtype, + device=x.device, + ) + buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + return self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + + sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0)) + buffer = bgmv_shrink(x, lora_a_stacked, sampler_indices, scale) + y = bgmv_expand(buffer, + lora_b_stacked, + y, + sampler_indices, + add_inputs=True) + return y.view_as(y_org) + + def _apply_bias( + self, + indices: torch.Tensor, + output: torch.Tensor, + output_slices: tuple[int, ...], + lora_bias_stacked: tuple[Optional[torch.Tensor], ...], + ): + """Applies bias to output + + Input shapes: + lora_bias_stacked: 3 element tuple of (num_loras, output_dim) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), + where n is number of slices + """ + org_output = output + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + + offset_left = 0 + for slice_idx, slice in enumerate(output_slices): + bias = lora_bias_stacked[slice_idx] + if bias is not None: + bias = bias.view(-1, bias.shape[-1]) + bias = bias[indices] + bias = torch.where(indices[:, None] == -1, 0, bias) + + bias = F.pad(bias, (offset_left, output.shape[1] - + (offset_left + slice), 0, 0)) + + output += bias + offset_left += slice + + return output.view_as(org_output) + + # This performs the same tensor ops as the base method, except it does them + # on the CPU then transfers the results to the TPU + def _update_base_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: list[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + # Make sure we don't accidentally collect outside operations + xm.mark_step() + + # Pad the prompt mapping to avoid running into recompiles on the TPU + # TODO: Should this happen inside mapping internally? If so how can we + # avoid having backend specific LoRAMapping classes? + mapping.prompt_mapping = self._pad_prompt_mapping( + mapping.prompt_mapping) + + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_offsets_tensor, + indices_len, + ) = convert_mapping( + mapping, + lora_index_to_id, + max_loras, + vocab_size, + extra_vocab_size, + "cpu", + long_lora_context, + ) + self._token_lora_indices = self._pad_to_shape( + base_indices, self._token_lora_indices.shape, + dims=1).to(self.device) + self._sampler_indices = self._pad_to_shape(sampler_indices, + self._sampler_indices.shape, + dims=1).to(self.device) + self._sampler_indices_padded = self._pad_to_shape( + sampler_indices_padded, self._sampler_indices_padded.shape, + dims=1).to(self.device) + self._embeddings_indices = self._pad_to_shape( + embeddings_indices, self._embeddings_indices.shape, + dims=2).to(self.device) + if long_lora_offsets_tensor is not None: + self._long_lora_indices = self._pad_to_shape( + long_lora_offsets_tensor, + self._long_lora_indices.shape, + dims=1).to(self.device) + else: + zeroed = torch.zeros_like(self._long_lora_indices.cpu(), + dtype=torch.int32) + self._long_lora_indices = zeroed.to(self.device) + self.indices_len[:] = indices_len + + def _update_prefill_metadata(self, + token_lora_tensor: torch.Tensor) -> None: + self.batch_size = 1 + self._lora_indices_per_batch[:self. + batch_size] = token_lora_tensor[:self. + batch_size] + + def _pad_prompt_mapping( + self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]: + num_reqs = len(prompt_mapping) + + # From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular + # import + MIN_NUM_SEQS = 8 + + padded_num_reqs = max(2**math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS) + pad_len = padded_num_reqs - num_reqs + + padding = [-1] * pad_len + return tuple(list(prompt_mapping) + padding) + + def _pad_to_shape(self, src, target_shape, dims=1): + if dims == 1: + pad_len = target_shape[0] - src.shape[0] + return F.pad(src, (0, pad_len), value=0).to(torch.int32) + else: + pad_rows = target_shape[0] - src.shape[0] + pad_cols = target_shape[1] - src.shape[1] + return F.pad(src, (0, pad_cols, 0, pad_rows), + value=0).to(torch.int32) diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py new file mode 100644 index 0000000..8430cb9 --- /dev/null +++ b/vllm/lora/punica_wrapper/utils.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import TYPE_CHECKING, Optional, Union + +import torch + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext + + +def compute_meta( + token_lora_tensor: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: + """ + Get the information required for the sgmv kernel. With the features: + 1. If consecutive requests in the batch use the same LoRA, this function + will combine them into a single request, improving sgmv kernel inference + performance. + 2. At the beginning of each prefill stage inference, recalculations are + needed based on the input, but only once. + """ + + lora_indices_tensor, seq_length_tensor = torch.unique_consecutive( + token_lora_tensor, return_counts=True) + cum_result = torch.cumsum(seq_length_tensor, dim=0) + b_seq_start_tensor = torch.zeros_like(seq_length_tensor) + b_seq_start_tensor[1:].copy_(cum_result[:-1]) + max_length = seq_length_tensor.max().item() + token_nums = seq_length_tensor.sum().item() + batch_size = lora_indices_tensor.size(0) + no_lora = False + # -1 means no lora should be applied. Use `no_lora` to determine whether + # the current step requires LoRA. If LoRA is not needed, the prefill stage + # does not need to launch the triton kernel, which can improve performance + if batch_size == 1 and lora_indices_tensor == -1: + no_lora = True + return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, token_nums, no_lora) + + +# TODO see if this can be vectorized +def convert_mapping( + mapping: "LoRAMapping", + lora_index_to_id: list[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + device: torch.device, + long_lora_context: Optional["LongContextLoRAContext"] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], list[int]]: + """Converts LoRAMapping to index tensors. + + Args: + mapping: LoRAMapping mapping rows in a batch to LoRA ids. + lora_index_to_id: List mapping LoRA ids to LoRA indices. + max_loras: Maximum number of LoRAs. + vocab_size: Model vocab size. + extra_vocab_size: Extra vocab size each LoRA can have. + long_lora_context: Passed if there are long context lora in a batch. + + Returns: + A tuple of tensors: + base_indices: Tensor of shape [batch_size] mapping batch rows to + LoRA indices. + sampler_indices: Tensor of shape [batch_size] mapping requests to + LoRA indices for sampler. For generation, this will be the + same as base_indices. For prefill, this will map requests + to LoRA indices. + sampler_indices_padded: Tensor of shape [batch_size] mapping + requests to LoRA indices for sampler with padding. + Same as sampler_indices, but -1 is replaced with + max_loras. + embeddings_indices: Tensor of shape [2, batch_size] mapping + requests to embedding indices. First row is for embeddings + added by the LoRAs, second row is for the LoRA.lora_a + embeddings. + long_lora_indices: Tensor of shape [batch_size] mapping + requests to RoPE offsets and rot dims for long LoRAs. + None if long context lora doesn't exist. + indices_len: List of lengths of the above tensors. It contains + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, long_lora_indices). + """ + index_mapping_indices: list[int] = list(mapping.index_mapping).copy() + embedding_indices = index_mapping_indices.copy() + lora_indices = index_mapping_indices.copy() + long_lora_offsets: Optional[torch.Tensor] = None + if long_lora_context: + long_lora_offsets = torch.zeros(len(index_mapping_indices), + device=device, + dtype=torch.long) + prompt_mapping: list[int] = [ + lora_index_to_id.index(x) if x > 0 else -1 + for x in mapping.prompt_mapping + ] + lora_idx = None + for i in range(len(index_mapping_indices)): + # TODO index can be slow. optimize + lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) + if index_mapping_indices[i] > 0 else -1) + embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 + lora_indices[i] = lora_idx + if long_lora_context: + assert long_lora_offsets is not None + lora_offset: int = long_lora_context.offsets_by_lora_id.get( + index_mapping_indices[i], 0) + long_lora_offsets[i] = lora_offset + + indices_list: list[Union[list[int], torch.Tensor]] = [ + index_mapping_indices, + lora_indices, + embedding_indices, + ] + if long_lora_context: + assert long_lora_offsets is not None + indices_list.append(long_lora_offsets) + indices = torch.tensor(indices_list, dtype=torch.long, device=device) + prompt_mapping_tensor = torch.tensor(prompt_mapping, + dtype=torch.long, + device=device) + embeddings_indices = torch.stack([ + indices[2] * extra_vocab_size, + indices[2] * (vocab_size + extra_vocab_size), + ]) + embeddings_indices = torch.where(embeddings_indices == -1, max_loras - 1, + embeddings_indices) + base_indices = indices[1] + sampler_indices = prompt_mapping_tensor + sampler_indices_padded = sampler_indices.clone() + sampler_indices_padded = torch.where(sampler_indices_padded == -1, + max_loras - 1, sampler_indices_padded) + sampler_indices_padded = torch.arange( + 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( + sampler_indices_padded * len(sampler_indices_padded)) + long_lora_indices = None + long_lora_indices_len: Optional[int] = None + if long_lora_context: + long_lora_indices = indices[3] + long_lora_indices_len = long_lora_indices.shape[-1] + # Contain length of indices tensors. Used to index into each tensor. + indices_len = [ + base_indices.shape[-1], + sampler_indices.shape[-1], + sampler_indices_padded.shape[-1], + embeddings_indices.shape[-1], + ] + if long_lora_indices_len is not None: + indices_len.append(long_lora_indices_len) + else: + # If long_lora doesn't exist,append None + indices_len.append(None) + + return ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_indices, + indices_len, + ) diff --git a/vllm/lora/request.py b/vllm/lora/request.py new file mode 100644 index 0000000..f895dc2 --- /dev/null +++ b/vllm/lora/request.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import warnings +from typing import Optional + +import msgspec + +from vllm.adapter_commons.request import AdapterRequest + + +class LoRARequest( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] + """ + Request for a LoRA adapter. + + Note that this class should be used internally. For online + serving, it is recommended to not allow users to use this class but + instead provide another layer of abstraction to prevent users from + accessing unauthorized LoRA adapters. + + lora_int_id must be globally unique for a given adapter. + This is currently not enforced in vLLM. + """ + __metaclass__ = AdapterRequest + + lora_name: str + lora_int_id: int + lora_path: str = "" + lora_local_path: Optional[str] = msgspec.field(default=None) + long_lora_max_len: Optional[int] = None + base_model_name: Optional[str] = msgspec.field(default=None) + tensorizer_config_dict: Optional[dict] = None + + def __post_init__(self): + if self.lora_local_path: + warnings.warn( + "The 'lora_local_path' attribute is deprecated " + "and will be removed in a future version. " + "Please use 'lora_path' instead.", + DeprecationWarning, + stacklevel=2) + if not self.lora_path: + self.lora_path = self.lora_local_path or "" + + # Ensure lora_path is not empty + assert self.lora_path, "lora_path cannot be empty" + + @property + def adapter_id(self): + return self.lora_int_id + + @property + def name(self): + return self.lora_name + + @property + def path(self): + return self.lora_path + + @property + def local_path(self): + warnings.warn( + "The 'local_path' attribute is deprecated " + "and will be removed in a future version. " + "Please use 'path' instead.", + DeprecationWarning, + stacklevel=2) + return self.lora_path + + @local_path.setter + def local_path(self, value): + warnings.warn( + "The 'local_path' attribute is deprecated " + "and will be removed in a future version. " + "Please use 'path' instead.", + DeprecationWarning, + stacklevel=2) + self.lora_path = value + + def __eq__(self, value: object) -> bool: + """ + Overrides the equality method to compare LoRARequest + instances based on lora_name. This allows for identification + and comparison lora adapter across engines. + """ + return isinstance(value, + self.__class__) and self.lora_name == value.lora_name and \ + self.lora_int_id == value.lora_int_id and \ + self.lora_path == value.lora_path + + def __hash__(self) -> int: + """ + Overrides the hash method to hash LoRARequest instances + based on lora_name. This ensures that LoRARequest instances + can be used in hash-based collections such as sets and dictionaries, + identified by their names across engines. + """ + return hash(self.lora_name) diff --git a/vllm/lora/resolver.py b/vllm/lora/resolver.py new file mode 100644 index 0000000..5808ae1 --- /dev/null +++ b/vllm/lora/resolver.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from collections.abc import Set +from dataclasses import dataclass, field +from typing import Optional + +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest + +logger = init_logger(__name__) + + +class LoRAResolver(ABC): + """Base class for LoRA adapter resolvers. + + This class defines the interface for resolving and fetching LoRA adapters. + Implementations of this class should handle the logic for locating and + downloading LoRA adapters from various sources (e.g. S3, cloud storage, + etc.). + """ + + @abstractmethod + async def resolve_lora(self, base_model_name: str, + lora_name: str) -> Optional[LoRARequest]: + """Abstract method to resolve and fetch a LoRA model adapter. + + Implements logic to locate and download LoRA adapter based on the name. + Implementations might fetch from a blob storage or other sources. + + Args: + base_model_name: The name/identifier of the base model to resolve. + lora_name: The name/identifier of the LoRA model to resolve. + + Returns: + Optional[LoRARequest]: The resolved LoRA model information, or None + if the LoRA model cannot be found. + """ + pass + + +@dataclass +class _LoRAResolverRegistry: + resolvers: dict[str, LoRAResolver] = field(default_factory=dict) + + def get_supported_resolvers(self) -> Set[str]: + """Get all registered resolver names.""" + return self.resolvers.keys() + + def register_resolver( + self, + resolver_name: str, + resolver: LoRAResolver, + ) -> None: + """Register a LoRA resolver. + Args: + resolver_name: Name to register the resolver under. + resolver: The LoRA resolver instance to register. + """ + if resolver_name in self.resolvers: + logger.warning( + "LoRA resolver %s is already registered, and will be " + "overwritten by the new resolver instance %s.", resolver_name, + resolver) + + self.resolvers[resolver_name] = resolver + + def get_resolver(self, resolver_name: str) -> LoRAResolver: + """Get a registered resolver instance by name. + Args: + resolver_name: Name of the resolver to get. + Returns: + The resolver instance. + Raises: + KeyError: If the resolver is not found in the registry. + """ + if resolver_name not in self.resolvers: + raise KeyError( + f"LoRA resolver '{resolver_name}' not found. " + f"Available resolvers: {list(self.resolvers.keys())}") + return self.resolvers[resolver_name] + + +LoRAResolverRegistry = _LoRAResolverRegistry() diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py new file mode 100644 index 0000000..ee196e3 --- /dev/null +++ b/vllm/lora/utils.py @@ -0,0 +1,240 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from typing import Optional, Union + +import huggingface_hub +import regex as re +from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, + HFValidationError, RepositoryNotFoundError) +from torch import nn +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.logger import init_logger +from vllm.lora.fully_sharded_layers import ( + ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA, + RowParallelLinearWithShardedLoRA) +# being imported for _all_lora_classes below +# yapf conflicts with isort for this block +# yapf: disable +from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + LinearScalingRotaryEmbeddingWithLoRA, + LogitsProcessorWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, + ReplicatedLinearWithLoRA, + RowParallelLinearWithLoRA, + VocabParallelEmbeddingWithLoRA) +from vllm.model_executor.layers.linear import LinearBase +# yapf: enable +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.models.utils import WeightsMapper + +logger = init_logger(__name__) + +_all_lora_classes: set[type[BaseLayerWithLoRA]] = { + VocabParallelEmbeddingWithLoRA, + ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithLoRA, + RowParallelLinearWithLoRA, + ReplicatedLinearWithLoRA, + LogitsProcessorWithLoRA, + ColumnParallelLinearWithShardedLoRA, + QKVParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLoRA, + RowParallelLinearWithShardedLoRA, + LinearScalingRotaryEmbeddingWithLoRA, +} + + +def from_layer(layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig] = None) -> nn.Module: + for lora_cls in _all_lora_classes: + # specifying kwargs so they can be easily accessed in decorator + if lora_cls.can_replace_layer(source_layer=layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config): + instance_layer = lora_cls(layer) + instance_layer.create_lora_weights(max_loras, lora_config, + model_config) + return instance_layer + return layer + + +def from_layer_logits_processor( + layer: LogitsProcessor, + lm_head: ParallelLMHead, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, +) -> LogitsProcessorWithLoRA: + ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim, + lm_head.weight.dtype, lm_head.weight.device, + lm_head.get_sharded_to_full_mapping()) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret + + +def replace_submodule(model: nn.Module, module_name: str, + new_module: nn.Module) -> nn.Module: + """Replace a submodule in a model with a new module.""" + parent = model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + setattr(parent, target_name, new_module) + return new_module + + +def parse_fine_tuned_lora_name( + name: str, + weights_mapper: Optional[WeightsMapper] = None +) -> tuple[str, bool, bool]: + """Parse the name of lora weights. + + args: + name: the name of the fine-tuned LoRA, e.g. + base_model.model.dense1.weight + weights_mapper: maps the name of weight, e.g. + `model.` -> `language_model.model.`, + return: + tuple(module_name, is_lora_a): + module_name: the name of the module, e.g. model.dense1, + is_lora_a whether the tensor is lora_a or lora_b. + is_bias whether the tensor is lora bias. + """ + + # LoRA weight qualified name usually starts with `base_model.model.`, + # so we remove the prefix `base_model.model.` to make the following + # mapping correctly. + if name.startswith("base_model.model."): + name = name.replace("base_model.model.", "") + name = weights_mapper._map_name(name) if weights_mapper else name + # recover the prefix `base_model.model.` + name = "base_model.model." + name + else: + name = weights_mapper._map_name(name) if weights_mapper else name + + # In some situations, we may not start with `base_model.model.`. + # If we don't (e.g., ibm-granite/granite-speech-3.3-8b), + # we should keep the prefix intact. + start_index = 2 if name.startswith("base_model.model.") else 0 + + parts = name.split(".") + if parts[-1] == "weight" and (parts[-2] == "lora_A" + or parts[-2] == "lora_B"): + new_name = ".".join(parts[start_index:-2]) + return new_name, parts[-2] == "lora_A", False + + if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": + new_name = ".".join(parts[start_index:-1]) + return new_name, parts[-1] == "lora_embedding_A", False + + if parts[-1] == "bias": + new_name = ".".join(parts[start_index:-2]) + return new_name, False, True + + raise ValueError(f"{name} is unsupported LoRA weight") + + +def is_regex_target_modules(load_modules: Union[str, list[str]], + expected_lora_modules: list[str]) -> bool: + """ + PEFT supports passing `target_modules` in the form of regular expressions, + such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to + determine whether the suffix in the regular expression is present in the + `expected_lora_modules`. + """ + + def is_valid_regex(pattern): + try: + re.compile(pattern) + return True + except re.error: + return False + + def is_subset(sub_list, full_list): + return set(sub_list).issubset(set(full_list)) + + # Similar to PEFT's processing logic, regex-related operations are only + # executed when the load_modules is a `str`. + if not isinstance(load_modules, str): + return False + + if is_valid_regex(load_modules): + match = re.search(r"\((.*?)\)\$?$", load_modules) + if match: + suffix = match.group(1).split("|") + return is_subset(suffix, expected_lora_modules) + return False + + +def get_supported_lora_modules(model: nn.Module) -> list[str]: + """ + In vLLM, all linear layers support LoRA. + """ + supported_lora_modules: set[str] = set() + # step1: traverse the model to get all the linear subfixes. + for name, module in model.named_modules(): + if isinstance(module, (LinearBase, )): + supported_lora_modules.add(name.split(".")[-1]) + # step 2: get the embedding modules if the model's mbedding_modules + # is not empty. + if model.embedding_modules: + for name in model.embedding_modules: + supported_lora_modules.add(name) + return list(supported_lora_modules) + + +def get_adapter_absolute_path(lora_path: str) -> str: + """ + Resolves the given lora_path to an absolute local path. + + If the lora_path is identified as a Hugging Face model identifier, + it will download the model and return the local snapshot path. + Otherwise, it treats the lora_path as a local file path and + converts it to an absolute path. + + Parameters: + lora_path (str): The path to the lora model, which can be an absolute path, + a relative path, or a Hugging Face model identifier. + + Returns: + str: The resolved absolute local path to the lora model. + """ + + # Check if the path is an absolute path. Return it no matter exists or not. + if os.path.isabs(lora_path): + return lora_path + + # If the path starts with ~, expand the user home directory. + if lora_path.startswith('~'): + return os.path.expanduser(lora_path) + + # Check if the expanded relative path exists locally. + if os.path.exists(lora_path): + return os.path.abspath(lora_path) + + # If the path does not exist locally, assume it's a Hugging Face repo. + try: + local_snapshot_path = huggingface_hub.snapshot_download( + repo_id=lora_path) + except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError, + HFValidationError): + # Handle errors that may occur during the download + # Return original path instead instead of throwing error here + logger.exception("Error downloading the HuggingFace model") + return lora_path + + return local_snapshot_path diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py new file mode 100644 index 0000000..7a4af74 --- /dev/null +++ b/vllm/lora/worker_manager.py @@ -0,0 +1,256 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from contextlib import contextmanager +from typing import Any, Literal, Optional, Union + +import torch + +from vllm.adapter_commons.utils import (add_adapter_worker, + apply_adapters_worker, + list_adapters_worker, + set_active_adapters_worker) +from vllm.adapter_commons.worker_manager import AbstractWorkerManager +from vllm.config import LoRAConfig +from vllm.logger import init_logger +from vllm.lora.models import (LoRAModel, LoRAModelManager, + LRUCacheLoRAModelManager, create_lora_manager) +from vllm.lora.peft_helper import PEFTHelper +from vllm.lora.request import LoRARequest +from vllm.lora.utils import get_adapter_absolute_path + +logger = init_logger(__name__) + + +class WorkerLoRAManager(AbstractWorkerManager): + """WorkerLoRAManager that manages LoRA models on the worker side. + + Every request, the requested LoRAs will be loaded (unless they are already + loaded), and every other LoRA will be unloaded.""" + + _manager_cls: type[LoRAModelManager] = LoRAModelManager + + def __init__( + self, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + embedding_modules: dict[str, str], + embedding_padding_modules: list[str], + lora_model_cls: type[LoRAModel] = LoRAModel, + max_position_embeddings: Optional[int] = None, + ): + self._lora_model_cls = lora_model_cls + self.embedding_modules = embedding_modules + self.embedding_padding_modules = embedding_padding_modules + self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self.vocab_size = vocab_size + self.lora_config = lora_config + self.max_position_embeddings = max_position_embeddings + super().__init__(device) + # Lazily initialized by create_lora_manager. + self._adapter_manager: LoRAModelManager + + @contextmanager + def dummy_lora_cache(self): + """Use this context manager to reuse the dummy lora model + to avoid creating it repeatedly.""" + self._cached_dummy_lora = None + yield + self._cached_dummy_lora = False + + @property + def is_enabled(self) -> bool: + return True + + def create_lora_manager( + self, + model: torch.nn.Module, + ) -> Any: + lora_manager = create_lora_manager( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + device=self.device, + lora_manager_cls=self._manager_cls, + ) + self._adapter_manager = lora_manager + return lora_manager.model + + def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: + try: + supported_lora_modules = ( + self._adapter_manager.supported_lora_modules) + packed_modules_mapping = ( + self._adapter_manager.packed_modules_mapping) + expected_lora_modules: list[str] = [] + for module in supported_lora_modules: + if module in packed_modules_mapping: + expected_lora_modules.extend( + packed_modules_mapping[module]) + else: + expected_lora_modules.append(module) + + expected_lora_modules = list(set(expected_lora_modules)) + lora_path = get_adapter_absolute_path(lora_request.lora_path) + + peft_helper = PEFTHelper.from_local_dir( + lora_path, self.max_position_embeddings, + lora_request.tensorizer_config_dict) + + # Validates the LoRA configuration against requirements before + # loading weights, throwing an exception if validation fails. + peft_helper.validate_legal(self.lora_config) + + # For some models like Qwen2VL, we need to use hf_to_vllm_mapper + # to ensure correct loading of lora weights. + model = self._adapter_manager.model + hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None) + + lora = self._lora_model_cls.from_local_checkpoint( + lora_path, + expected_lora_modules, + peft_helper=peft_helper, + lora_model_id=lora_request.lora_int_id, + device="cpu", + dtype=self.lora_config.lora_dtype, + target_embedding_padding=self.vocab_size + + self.lora_config.lora_extra_vocab_size, + embedding_modules=self.embedding_modules, + embedding_padding_modules=self.embedding_padding_modules, + tensorizer_config_dict=lora_request.tensorizer_config_dict, + weights_mapper=hf_to_vllm_mapper) + + except FileNotFoundError as e: + # FileNotFoundError should be raised if both + # - No adapter found to download from huggingface (or in + # offline mode) + # - No local adapter files found at `lora_request.lora_path` + # For NotFoundError + raise ValueError( + f"Loading lora {lora_request.lora_name} failed: No adapter " + f"found for {lora_request.lora_path}") from e + except Exception as e: + # For BadRequestError + raise e + + if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: + raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} " + f"is greater than lora_extra_vocab_size " + f"{self.lora_config.lora_extra_vocab_size}.") + return lora + + def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: + if lora_request.lora_int_id in self.list_adapters(): + return False + if isinstance(self._cached_dummy_lora, LoRAModel): + dummy_lora = self._cached_dummy_lora.clone( + lora_request.lora_int_id) + else: + dummy_lora = self._adapter_manager.create_dummy_lora( + lora_request.lora_int_id, rank, 1, self.embedding_modules) + if self._cached_dummy_lora is None: + self._cached_dummy_lora = dummy_lora + return self._adapter_manager.add_adapter(dummy_lora) + + def pin_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.pin_adapter(adapter_id) + + def set_active_adapters(self, requests: set[Any], + mapping: Optional[Any]) -> None: + set_active_adapters_worker(requests, mapping, self._apply_adapters, + self._adapter_manager.set_adapter_mapping) + + def _apply_adapters(self, adapter_requests: set[Any]) -> None: + apply_adapters_worker(adapter_requests, self.list_adapters, + self._adapter_manager.adapter_slots, + self.remove_adapter, self.add_adapter) + + def add_adapter(self, adapter_request: Any) -> bool: + return add_adapter_worker(adapter_request, self.list_adapters, + self._load_adapter, + self._adapter_manager.add_adapter, + self._adapter_manager.activate_adapter) + + def remove_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.remove_adapter(adapter_id) + + def remove_all_adapters(self): + self._adapter_manager.remove_all_adapters() + + def list_adapters(self) -> set[int]: + return list_adapters_worker(self._adapter_manager.list_adapters) + + +class LRUCacheWorkerLoRAManager(WorkerLoRAManager): + """WorkerLoRAManager that manages LoRA models on the worker side. + + Uses an LRU Cache. Every request, the requested LoRAs will be loaded + (unless they are already loaded) and least recently used LoRAs will + be unloaded if the cache is above capacity.""" + + _manager_cls: type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager + + def create_lora_manager( + self, + model: torch.nn.Module, + ) -> Any: + lora_manager = create_lora_manager( + model, + lora_manager_cls=self._manager_cls, + max_num_seqs=self.max_num_seqs, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + device=self.device, + max_num_batched_tokens=self.max_num_batched_tokens, + ) + self._adapter_manager = lora_manager + return lora_manager.model + + def _apply_adapters(self, lora_requests: set[LoRARequest]) -> None: + loras_map = { + lora_request.lora_int_id: lora_request + for lora_request in lora_requests if lora_request + } + if len(loras_map) > self._adapter_manager.lora_slots: + raise RuntimeError( + f"Number of requested LoRAs ({len(loras_map)}) is greater " + "than the number of GPU LoRA slots " + f"({self._adapter_manager.lora_slots}).") + for lora in loras_map.values(): + self.add_adapter(lora) + + def add_adapter(self, lora_request: LoRARequest) -> bool: + # Note that this method is not thread-safe. It may be invoked multiple + # times for the same adapter when using multiple API servers. + # This is ok because it's currently only called from + # the single-threaded core engine loop. + + if lora_request.lora_int_id not in self.list_adapters(): + # Load the new adapter first to ensure it is actually valid, before + # evicting any existing adapters. + # This may cause the # of loaded lora adapters to very temporarily + # exceed `--max-cpu-loras`. + lora = self._load_adapter(lora_request) + + # Loading succeeded, now check if we will exceed cache capacity and + # evict if the oldest adapter if so + if len(self._adapter_manager) + 1 > self._adapter_manager.capacity: + assert isinstance(self._adapter_manager, + LRUCacheLoRAModelManager) + self._adapter_manager.remove_oldest_adapter() + # Then add the new adapter to the cache + loaded = self._adapter_manager.add_adapter(lora) + else: + # If the lora is already loaded, just touch it to + # update its position in the caches + loaded = self._adapter_manager.get_adapter( + lora_request.lora_int_id) is not None + self._adapter_manager.activate_adapter(lora_request.lora_int_id) + return loaded diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py new file mode 100644 index 0000000..55dfe80 --- /dev/null +++ b/vllm/model_executor/__init__.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.model_executor.parameter import (BasevLLMParameter, + PackedvLLMParameter) +from vllm.model_executor.sampling_metadata import (SamplingMetadata, + SamplingMetadataCache) +from vllm.model_executor.utils import set_random_seed + +__all__ = [ + "SamplingMetadata", + "SamplingMetadataCache", + "set_random_seed", + "BasevLLMParameter", + "PackedvLLMParameter", +] diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py new file mode 100644 index 0000000..9c88721 --- /dev/null +++ b/vllm/model_executor/custom_op.py @@ -0,0 +1,208 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch.nn as nn + +from vllm.config import get_current_vllm_config +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + + +class CustomOp(nn.Module): + """ + Base class for custom ops. + Dispatches the forward method to the appropriate backend. + """ + + def __new__(cls, *args, **kwargs): + try: + op_name = cls.__name__ + except AttributeError: + raise TypeError( + f"Cannot instantiate '{cls.__name__}': its 'name' attribute " + f"was not set, possibly because it was not decorated with " + f"@CustomOp.register, or it's the CustomOp base class itself." + ) from None + + if op_name not in cls.op_registry_oot: + op_cls_to_instantiate = cls + else: + op_cls_to_instantiate = cls.op_registry_oot[op_name] + logger.debug("Instantiating custom op: %s using %s", op_name, + str(op_cls_to_instantiate)) + return super().__new__(op_cls_to_instantiate) + + def __init__(self): + super().__init__() + self._forward_method = self.dispatch_forward() + + def forward(self, *args, **kwargs): + return self._forward_method(*args, **kwargs) + + def forward_native(self, *args, **kwargs): + """PyTorch-native implementation of the forward method. + This method is optional. If implemented, it can be used with compilers + such as torch.compile or PyTorch XLA. Also, it can be used for testing + purposes. + """ + raise NotImplementedError + + def forward_cuda(self, *args, **kwargs): + raise NotImplementedError + + def forward_hip(self, *args, **kwargs): + # By default, we assume that HIP ops are compatible with CUDA ops. + return self.forward_cuda(*args, **kwargs) + + def forward_xpu(self, *args, **kwargs): + # By default, we assume that XPU ops are compatible with the + # PyTorch-native implementation. + return self.forward_native(*args, **kwargs) + + def forward_cpu(self, *args, **kwargs): + # By default, we assume that CPU ops are compatible with CUDA ops. + return self.forward_cuda(*args, **kwargs) + + def forward_tpu(self, *args, **kwargs): + # By default, we assume that TPU ops are compatible with the + # PyTorch-native implementation. + # NOTE(woosuk): This is a placeholder for future extensions. + return self.forward_native(*args, **kwargs) + + def forward_hpu(self, *args, **kwargs): + # By default, we assume that Gaudi ops are compatible with the + # PyTorch-native implementation. + return self.forward_native(*args, **kwargs) + + def forward_neuron(self, *args, **kwargs): + # By default, we assume that Neuron ops are compatible with the + # PyTorch-native implementation. + return self.forward_native(*args, **kwargs) + + def forward_oot(self, *args, **kwargs): + # By default, we assume that OOT ops are compatible with the + # PyTorch-native implementation. + return self.forward_native(*args, **kwargs) + + def dispatch_forward(self): + # NOTE(woosuk): Here we assume that vLLM was built for only one + # specific backend. Currently, we do not support dynamic dispatching. + compilation_config = get_current_vllm_config().compilation_config + enabled = self.enabled() + if enabled: + compilation_config.enabled_custom_ops.update([self.__class__.name]) + else: + compilation_config.disabled_custom_ops.update( + [self.__class__.name]) + + if not enabled: + return self.forward_native + + if current_platform.is_rocm(): + return self.forward_hip + elif current_platform.is_cpu(): + return self.forward_cpu + elif current_platform.is_hpu(): + return self.forward_hpu + elif current_platform.is_tpu(): + return self.forward_tpu + elif current_platform.is_xpu(): + return self.forward_xpu + elif current_platform.is_neuron(): + return self.forward_neuron + elif current_platform.is_out_of_tree(): + return self.forward_oot + else: + return self.forward_cuda + + @classmethod + def enabled(cls) -> bool: + # if no name, then it was not registered + compilation_config = get_current_vllm_config().compilation_config + custom_ops = compilation_config.custom_ops + if not hasattr(cls, "name"): + logger.warning_once( + "Custom op %s was not registered, which means it won't appear in the op registry. It will be enabled/disabled based on the global settings.", # noqa: E501 + cls.__name__, + ) + return CustomOp.default_on() + + enabled = f"+{cls.name}" in custom_ops + disabled = f"-{cls.name}" in custom_ops + assert not (enabled + and disabled), f"Cannot enable and disable {cls.name}" + + return (CustomOp.default_on() or enabled) and not disabled + + @staticmethod + def default_on() -> bool: + """ + On by default if PyTorch Inductor is not used. + Specifying 'all' or 'none' in custom_op takes precedence. + """ + from vllm.config import CompilationLevel + compilation_config = get_current_vllm_config().compilation_config + default_on = (compilation_config.level < CompilationLevel.PIECEWISE + or not compilation_config.use_inductor) + count_none = compilation_config.custom_ops.count("none") + count_all = compilation_config.custom_ops.count("all") + return default_on and not count_none > 0 or count_all > 0 + + # Dictionary of all custom ops (classes, indexed by registered name). + # To check if an op with a name is enabled, call .enabled() on the class. + # Examples: + # - MyOp.enabled() + # - op_registry["my_op"].enabled() + op_registry: dict[str, type['CustomOp']] = {} + op_registry_oot: dict[str, type['CustomOp']] = {} + + # Decorator to register custom ops. + @classmethod + def register(cls, name: str): + + def decorator(op_cls): + assert name not in cls.op_registry, f"Duplicate op name: {name}" + op_cls.name = name + cls.op_registry[name] = op_cls + return op_cls + + return decorator + + # Decorator to register out-of-tree(oot) custom ops. + # For OOT custom ops: + # if in-tree layer class is registered with an oot_custom_op layer, + # the oot_custom_op layer will be used instead. + # Example: + # - @UnquantizedFusedMoEMethod.register_oot + # class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod) + # or + # - @CustomOP.register_oot(name="UnquantizedFusedMoEMethod") + @classmethod + def register_oot(cls, _decorated_op_cls=None, name: Optional[str] = None): + + def decorator(op_cls): + reg_name = name if name is not None else cls.__name__ + assert reg_name not in cls.op_registry_oot, \ + f"Duplicate op name: {reg_name}" + op_cls.name = reg_name + cls.op_registry_oot[reg_name] = op_cls + return op_cls + + if _decorated_op_cls is None: + # Called with parentheses: @CustomOP.register_oot() + # or @CustomOP.register_oot(name="...") + # So, _decorated_op_cls is None. + # We return the actual decorator function. + return decorator + elif isinstance(_decorated_op_cls, type): # Check if it's a class + # Called without parentheses: @CustomOP.register_oot + # The first argument is the class itself. + # We call the 'decorator' function immediately with the class. + return decorator(_decorated_op_cls) + else: + # Handle other unexpected cases if necessary + raise TypeError("Decorator can only be applied to classes.") diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py new file mode 100644 index 0000000..3c2998b --- /dev/null +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from vllm.logger import init_logger +from vllm.model_executor.guided_decoding.utils import ( + convert_lark_to_gbnf, grammar_is_likely_lark, + has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features) +from vllm.reasoning import ReasoningParserManager + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + from vllm.config import ModelConfig + from vllm.logits_process import LogitsProcessor + from vllm.sampling_params import GuidedDecodingParams + +logger = init_logger(__name__) + + +def maybe_backend_fallback( + guided_params: GuidedDecodingParams) -> GuidedDecodingParams: + + def fallback_or_error(guided_params: GuidedDecodingParams, message: str, + fallback: str) -> None: + """Change the backend to the specified fallback with a warning log, + or raise a ValueError if the `disable_fallback` option is specified.""" + if guided_params.disable_fallback: + raise ValueError(message) + + logger.warning("%s Falling back to use %s instead.", message, fallback) + guided_params.backend = fallback + + # `auto` was added for V1 to explicitly declare a mode that has fallbacks + # in place. If that is specified with V0, treat it as `xgrammar`, as we have + # fallbacks enabled for that and it is the V0 default. + if guided_params.backend == "auto": + guided_params.backend = "xgrammar" + + # lm-format-enforce doesn't support grammar, fallback to xgrammar + if guided_params.backend == "lm-format-enforcer": + if guided_params.grammar is not None: + fallback_or_error( + guided_params, + "lm-format-enforcer does not support grammar guided decoding.", + "xgrammar") + + # lm-format-enforcer doesn't support some JSON schema features + elif (guided_params.json is not None + and has_lmf_unsupported_json_features(guided_params.json)): + fallback_or_error( + guided_params, + "lm-format-enforcer does not support advanced JSON schema " + "features like patterns or numeric ranges.", "outlines") + + if guided_params.backend == "xgrammar": + from vllm.model_executor.guided_decoding.xgrammar_decoding import ( + xgr_installed) + + # xgrammar doesn't support some JSON schema features + if (guided_params.json is not None and + has_xgrammar_unsupported_json_features(guided_params.json)): + fallback_or_error( + guided_params, + "xgrammar does not support advanced JSON schema features like " + "string length, item limits, or property bounds.", "outlines") + + # xgrammar only supports GBNF grammars, so we must convert Lark. + # We must check if the grammar is likely Lark and if that + # grammar is convertible to GBNF + elif (guided_params.grammar is not None + and grammar_is_likely_lark(guided_params.grammar)): + try: + convert_lark_to_gbnf(guided_params.grammar) + except Exception: + fallback_or_error( + guided_params, + "xgrammar does not support Lark grammars and the " + "grammar failed to convert to GBNF.", "outlines") + + # If the xgrammar module cannot be imported successfully, + # we should still allow users to use guided decoding with a fallback. + elif not xgr_installed: + fallback_or_error( + guided_params, + "xgrammar module cannot be imported successfully.", "outlines") + + if (guided_params.backend == "outlines" + and guided_params.json_object is not None): + # outlines doesn't support json_object, fallback to guidance + fallback_or_error(guided_params, + "outlines does not support json_object.", "guidance") + + return guided_params + + +async def get_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizer, + model_config: ModelConfig, + reasoning_backend: str | None = None) -> LogitsProcessor | None: + + reasoner = None + if reasoning_backend: + reasoner_class = ReasoningParserManager.get_reasoning_parser( + reasoning_backend) + reasoner = reasoner_class(tokenizer) + + guided_params = maybe_backend_fallback(guided_params) + + # CFG grammar not supported by LMFE, so we use outlines instead + if guided_params.backend == 'outlines': + # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 + from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa + get_outlines_guided_decoding_logits_processor) + return await get_outlines_guided_decoding_logits_processor( + guided_params, tokenizer, reasoner) + if guided_params.backend == 'lm-format-enforcer': + from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa + get_local_lm_format_enforcer_guided_decoding_logits_processor) + return get_local_lm_format_enforcer_guided_decoding_logits_processor( + guided_params, tokenizer) + if guided_params.backend == 'xgrammar': + from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa + get_local_xgrammar_guided_decoding_logits_processor) + return get_local_xgrammar_guided_decoding_logits_processor( + guided_params, tokenizer, model_config, reasoner) + if guided_params.backend == 'guidance': + from vllm.model_executor.guided_decoding.guidance_decoding import ( + get_local_guidance_guided_decoding_logits_processor) + return get_local_guidance_guided_decoding_logits_processor( + guided_params, tokenizer) + raise ValueError( + f"Unknown guided decoding backend '{guided_params.backend}'. " + "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'" + ) + + +def get_local_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizer, + model_config: ModelConfig, + reasoning_backend: str | None = None) -> LogitsProcessor | None: + guided_params = maybe_backend_fallback(guided_params) + + reasoner = None + if reasoning_backend: + reasoner_class = ReasoningParserManager.get_reasoning_parser( + reasoning_backend) + reasoner = reasoner_class(tokenizer) + + # CFG grammar not supported by LMFE, so we use outlines instead + if guided_params.backend == 'outlines': + # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 + from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa + get_local_outlines_guided_decoding_logits_processor) + return get_local_outlines_guided_decoding_logits_processor( + guided_params, tokenizer, reasoner) + if guided_params.backend == 'lm-format-enforcer': + from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa + get_local_lm_format_enforcer_guided_decoding_logits_processor) + return get_local_lm_format_enforcer_guided_decoding_logits_processor( + guided_params, tokenizer) + if guided_params.backend == 'xgrammar': + from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa + get_local_xgrammar_guided_decoding_logits_processor) + return get_local_xgrammar_guided_decoding_logits_processor( + guided_params, tokenizer, model_config, reasoner) + if guided_params.backend == 'guidance': + from vllm.model_executor.guided_decoding.guidance_decoding import ( + get_local_guidance_guided_decoding_logits_processor) + return get_local_guidance_guided_decoding_logits_processor( + guided_params, tokenizer) + + raise ValueError( + f"Unknown guided decoding backend '{guided_params.backend}'. " + "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'" + ) diff --git a/vllm/model_executor/guided_decoding/guidance_decoding.py b/vllm/model_executor/guided_decoding/guidance_decoding.py new file mode 100644 index 0000000..05b6a1c --- /dev/null +++ b/vllm/model_executor/guided_decoding/guidance_decoding.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json + +import llguidance +from regex import escape as regex_escape +from transformers import PreTrainedTokenizerBase + +from vllm.model_executor.guided_decoding.guidance_logits_processors import ( + GuidanceLogitsProcessor) +from vllm.sampling_params import GuidedDecodingParams +from vllm.v1.structured_output.backend_guidance import ( + process_for_additional_properties) + + +def get_local_guidance_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizerBase) -> GuidanceLogitsProcessor: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + """ + + grm = "" + any_whitespace = not guided_params.disable_any_whitespace + if (guide_json := guided_params.json) is not None: + # Optionally set additionalProperties to False at the top-level + # By default, other backends do not allow additional top-level + # properties, so this makes guidance more similar to other backends + if guided_params.disable_additional_properties: + if not isinstance(guide_json, str): + guide_json = json.dumps(guide_json) + guide_json = process_for_additional_properties(guide_json) + + grm = llguidance.LLMatcher.grammar_from_json_schema( + guide_json, + overrides={"whitespace_pattern": guided_params.whitespace_pattern}, + defaults={ + "whitespace_flexible": any_whitespace, + }) + elif guided_params.json_object: + grm = llguidance.LLMatcher.grammar_from_json_schema( + '{"type": "object"}', + overrides={"whitespace_pattern": guided_params.whitespace_pattern}, + defaults={ + "whitespace_flexible": any_whitespace, + }) + elif guided_params.regex: + grm = llguidance.grammar_from("regex", guided_params.regex) + elif guided_params.choice: + # choice just uses regex + choices = (regex_escape(str(choice)) + for choice in guided_params.choice) + choices_regex = "(" + "|".join(choices) + ")" + grm = llguidance.grammar_from("regex", choices_regex) + elif guided_params.grammar: + # this supports Lark and GBNF + grm = llguidance.grammar_from("grammar", guided_params.grammar) + + if grm: + return GuidanceLogitsProcessor(grm, tokenizer) + + raise ValueError("Unknown guided decoding mode") diff --git a/vllm/model_executor/guided_decoding/guidance_logits_processors.py b/vllm/model_executor/guided_decoding/guidance_logits_processors.py new file mode 100644 index 0000000..379b5ea --- /dev/null +++ b/vllm/model_executor/guided_decoding/guidance_logits_processors.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy +import os +from typing import Any + +import llguidance +import llguidance.hf +import llguidance.torch +import torch +from transformers import PreTrainedTokenizerBase + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class GuidanceLogitsProcessor: + """Base Guidance Logits Processor""" + + cached_tokenizers: dict[str, Any] = {} + + def __init__( + self, + grammar: str, + tokenizer: PreTrainedTokenizerBase, + ) -> None: + """Base Guidance Logits Processor + + Args: + grammar (str) + grammar to guide the generation + tokenizer (PreTrainedTokenizerBase) + model's tokenizer + """ + self.grammar = grammar + self.tokenizer = tokenizer + self.tokenizer_name = tokenizer.name_or_path + self.ll_tokenizer = None + self.ll_matcher = None + self.bitmask = None + self.new_sampling = False + self.initialized = False + + def clone(self) -> "GuidanceLogitsProcessor": + cloned = copy.copy(self) + if self.initialized: + cloned.ll_matcher = llguidance.LLMatcher( + self.ll_tokenizer, # type: ignore[assignment] + self.grammar, + log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")), + ) + self.bitmask = llguidance.torch.allocate_token_bitmask( + 1, self.ll_tokenizer.vocab_size) # type: ignore[attr-defined] + return cloned + + def _initialize(self): + if self.initialized: + return + + ll_tokenizer = self.cached_tokenizers.get(self.tokenizer.name_or_path, + None) + if ll_tokenizer is None: + ll_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None) + self.cached_tokenizers[self.tokenizer.name_or_path] = ll_tokenizer + + self.ll_tokenizer = ll_tokenizer + self.ll_matcher = llguidance.LLMatcher( + self.ll_tokenizer, + self.grammar, + log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")), + ) + + # create reusable bitmask + self.bitmask = llguidance.torch.allocate_token_bitmask( + 1, self.ll_tokenizer.vocab_size) # type: ignore[attr-defined] + + self.initialized = True + + def __call__( + self, + input_ids: list[int], + scores: torch.Tensor, + ) -> torch.Tensor: + # we initialize the guidance model here + # to avoid pickling ll_tokenizer and ll_interpreter + self._initialize() + + if self.new_sampling and len(input_ids) > 0: + self.ll_matcher.consume_token( # type: ignore[attr-defined] + input_ids[-1]) + err = self.ll_matcher.get_error() # type: ignore[attr-defined] + if err: + logger.warning("Error in LLMatcher: %s", err) + + llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask, + 0) + llguidance.torch.apply_token_bitmask_inplace( + scores, + self.bitmask.to(scores.device)) # type: ignore[attr-defined] + + self.new_sampling = True + + return scores diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py new file mode 100644 index 0000000..fa97b6d --- /dev/null +++ b/vllm/model_executor/guided_decoding/guided_fields.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import Optional, TypedDict, Union + + +# These classes are deprecated, see SamplingParams +class LLMGuidedOptions(TypedDict, total=False): + guided_json: Union[dict, str] + guided_regex: str + guided_choice: list[str] + guided_grammar: str + guided_decoding_backend: str + guided_whitespace_pattern: str + guided_json_object: bool + + +@dataclass +class GuidedDecodingRequest: + """One of the fields will be used to retrieve the logit processor.""" + guided_json: Optional[Union[dict, str]] = None + guided_regex: Optional[str] = None + guided_choice: Optional[list[str]] = None + guided_grammar: Optional[str] = None + guided_decoding_backend: Optional[str] = None + guided_whitespace_pattern: Optional[str] = None + guided_json_object: Optional[bool] = None + structural_tag: Optional[str] = None + + def __post_init__(self): + """Validate that some fields are mutually exclusive.""" + guide_count = sum(x is not None + for x in (self.guided_json, self.guided_regex, + self.guided_choice, self.guided_grammar, + self.guided_json_object, + self.structural_tag)) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding but multiple are " + f"specified: {self.__dict__}") diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py new file mode 100644 index 0000000..f9b51f4 --- /dev/null +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from functools import lru_cache +from json import loads as json_loads +from typing import Optional, Union + +from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser, + RegexParser, StringParser, + TokenEnforcerTokenizerData, UnionParser) +from lmformatenforcer.integrations.vllm import ( + build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data) +from transformers import PreTrainedTokenizerBase + +from vllm.logits_process import LogitsProcessor +from vllm.sampling_params import GuidedDecodingParams + + +def get_local_lm_format_enforcer_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer) -> Optional[LogitsProcessor]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. + """ + + tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer) + character_level_parser: CharacterLevelParser + if guided_params.json: + schema_dict = _normalize_json_schema_object(guided_params.json) + character_level_parser = JsonSchemaParser(schema_dict) + elif guided_params.choice: + character_level_parser = UnionParser( + [StringParser(choice) for choice in guided_params.choice]) + elif guided_params.regex: + character_level_parser = RegexParser(guided_params.regex) + elif guided_params.grammar: + # CFG grammar not supported by LMFE + raise ValueError("Cannot construct a guided decoding logits processor" + " using the grammar option with the" + " lm_format_enforcer backend.") + elif guided_params.json_object: + # None means any json object + character_level_parser = JsonSchemaParser(None) + else: + return None + + logits_processor = build_vllm_logits_processor(tokenizer_data, + character_level_parser) + return logits_processor + + +def _normalize_json_schema_object(schema: Union[str, dict]) -> dict: + if isinstance(schema, str): + return json_loads(schema) + if isinstance(schema, dict): + return schema + raise AssertionError(f"Unsupported schema type {schema}") + + +@lru_cache +def _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData: + return build_vllm_token_enforcer_tokenizer_data(tokenizer) diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py new file mode 100644 index 0000000..26c2d95 --- /dev/null +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -0,0 +1,155 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import concurrent.futures +import os +from enum import Enum +from json import dumps as json_dumps +from typing import Optional, Union + +from regex import escape as regex_escape +from transformers import PreTrainedTokenizerBase + +from vllm.model_executor.guided_decoding.outlines_logits_processors import ( + CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) +from vllm.reasoning import ReasoningParser +from vllm.sampling_params import GuidedDecodingParams + + +class GuidedDecodingMode(Enum): + JSON = "json" + REGEX = "regex" + CHOICE = "choice" + GRAMMAR = "grammar" + + +# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark +# the main difference is that we changed the start: value to +# start: object | array, so we are denying scalar values as the root of the +# JSON. Starting with scalars as the root seems to cause llama to generate +# without stop. +JSON_GRAMMAR = r""" +?start: object | array + +?value: object +| array +| UNESCAPED_STRING +| SIGNED_NUMBER -> number +| "true" -> true +| "false" -> false +| "null" -> null + +array : "[" [value ("," value)*] "]" +object : "{" [pair ("," pair)*] "}" +pair : UNESCAPED_STRING ":" value + +%import common.UNESCAPED_STRING +%import common.SIGNED_NUMBER +%import common.WS + +%ignore WS +""" + +global_thread_pool = None # used for generating logits processor fsm + +# It's not yet clear that using more provides a benefit, and it could +# potentially starve other processes on the machine. We'll cap this for now and +# adjust later if testing proves it to help overcome a bottleneck. +_MAX_THREADPOOL_WORKERS = 16 + + +async def get_outlines_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[ReasoningParser], +) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, + None]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. + """ + global global_thread_pool + guide, mode = _get_guide_and_mode(guided_params) + if not guide or not mode: + return None + + if global_thread_pool is None: + max_workers = os.cpu_count() or 2 + if max_workers > _MAX_THREADPOOL_WORKERS: + max_workers = _MAX_THREADPOOL_WORKERS + global_thread_pool = concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers) + loop = asyncio.get_running_loop() + + return await loop.run_in_executor(global_thread_pool, + _get_logits_processor, guide, tokenizer, + mode, guided_params.whitespace_pattern, + reasoner) + + +def get_local_outlines_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[ReasoningParser], +) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, + None]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. + """ + guide, mode = _get_guide_and_mode(guided_params) + if not guide or not mode: + return None + + return _get_logits_processor(guide, tokenizer, mode, + guided_params.whitespace_pattern, reasoner) + + +def _get_guide_and_mode( + guided_params: GuidedDecodingParams +) -> Union[tuple[str, GuidedDecodingMode], tuple[None, None]]: + if guided_params.json: + if isinstance(guided_params.json, dict): + # turn dict into hashable string + json = json_dumps(guided_params.json) + else: + json = guided_params.json + return json, GuidedDecodingMode.JSON + elif guided_params.regex: + return guided_params.regex, GuidedDecodingMode.REGEX + elif guided_params.choice: + # choice just uses regex + choices = [ + regex_escape(str(choice)) for choice in guided_params.choice + ] + choices_regex = "(" + "|".join(choices) + ")" + return choices_regex, GuidedDecodingMode.CHOICE + elif guided_params.grammar: + return guided_params.grammar, GuidedDecodingMode.GRAMMAR + elif guided_params.json_object: + return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR + else: + return None, None + + +def _get_logits_processor( + guide: str, + tokenizer: PreTrainedTokenizerBase, + mode: GuidedDecodingMode, + whitespace_pattern: Union[str, None], + reasoner: Optional[ReasoningParser], +) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]: + if mode == GuidedDecodingMode.JSON: + return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern, + reasoner) + elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: + return RegexLogitsProcessor(guide, tokenizer, reasoner) + elif mode == GuidedDecodingMode.GRAMMAR: + return CFGLogitsProcessor(guide, tokenizer, reasoner) + else: + raise ValueError(f"Unknown guided decoding mode {mode}") diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py new file mode 100644 index 0000000..4ef4db7 --- /dev/null +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -0,0 +1,284 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024- the Outlines developers +# This file is adapted from +# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import json +from collections import defaultdict +from functools import lru_cache +from typing import Callable, Optional, Union + +import numpy as np +import torch +from outlines import grammars +from outlines.caching import cache, disable_cache +from outlines.fsm.guide import (CFGGuide, CFGState, Generate, Guide, + RegexGuide, Write) +from outlines.fsm.parsing import PartialLark +from outlines_core.fsm.json_schema import build_regex_from_schema +from pydantic import BaseModel +from transformers import PreTrainedTokenizerBase + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.reasoning import ReasoningParser + +logger = init_logger(__name__) + +if envs.VLLM_V0_USE_OUTLINES_CACHE: + logger.warning("Enabling outlines cache. This is an unbounded on-disk " + "cache. It may consume a lot of disk space and should " + "not be used with untrusted clients.") +else: + disable_cache() + + +class BaseLogitsProcessor: + + def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]): + self._guide: Guide = guide + self._reasoner: Optional[ReasoningParser] = reasoner + # CFGState is used for the FSM state for CFGGuide + self._fsm_state: defaultdict[int, Union[int, + CFGState]] = defaultdict(int) + + def clone(self) -> "BaseLogitsProcessor": + cloned = copy.copy(self) + cloned._guide = self._guide.copy() + cloned._fsm_state = copy.deepcopy(self._fsm_state) + return cloned + + def __call__(self, input_ids: list[int], + scores: torch.Tensor) -> torch.Tensor: + """Use the FSM to bias the logits before sampling the next token.""" + + # Skip the structured logits processing if reasoning is not finished. + # reasoner is not None only when `--reasoning-parser` is set. + if self._reasoner is not None: + if not self._reasoner.is_reasoning_end(input_ids): + return scores + else: + # Remove the reasoning tokens from the input_ids + # We need this because our implementation relies on the + # hash of the input_ids to store the FSM state. + input_ids = self._reasoner.extract_content_ids(input_ids) + + seq_id = hash(tuple(input_ids)) + + if len(input_ids) > 0: + last_token = input_ids[-1] + last_seq_id = hash(tuple(input_ids[:-1])) + self._fsm_state[seq_id] = self._guide.get_next_state( + state=self._fsm_state[last_seq_id], token_id=last_token) + else: + # Note: this is a hack. + # Lark pickling does not work properly (silent failure), + # which breaks the RPC (which uses python pickleing). + # We need to find a better solution. + # On the first time this is called, we simply re-create + # the Lark object. + if isinstance(self._guide, CFGGuide): + self._guide.parser = PartialLark( + self._guide.cfg_string, + parser="lalr", + import_paths=[grammars.GRAMMAR_PATH], + ) + self._fsm_state[seq_id] = CFGState( + parser_state=self._guide.parser.parse(""), prev_token=None) + + instruction = self._guide.get_next_instruction( + state=self._fsm_state[seq_id]) + + if type(instruction) == Generate: # noqa: E721 + allowed_tokens = instruction.tokens + elif type(instruction) == Write: # noqa: E721 + # TODO: support fast forward tokens + allowed_tokens = [instruction.tokens[0]] + else: + raise TypeError( + f"Unsupported instruction type {type(instruction)}") + + mask = torch.full((scores.shape[-1], ), + -torch.inf, + device=scores.device) + # The tokenizer may support more token ids than the model can generate, + # eg. Llama 3.2 Vision models have an `<|image|>` token with id 128256 + # but scores.shape == torch.Size([128256]) + # Using NumPy is faster for filtering token ids + allowed_tokens = np.array(allowed_tokens, dtype=np.int64) + allowed_tokens = torch.tensor(allowed_tokens, device=scores.device) + allowed_tokens = allowed_tokens.masked_select( + allowed_tokens < scores.shape[-1]) + mask.index_fill_(0, allowed_tokens, 0) + if current_platform.is_hpu(): + # Workaround for HPU bug where add_() raise RuntimeError: + # synNodeCreateWithId failed for node: strided_insert + # with synStatus 1 [Invalid argument], hopefully it will + # be fixed in the future releases of the HPU runtime. + scores = scores.add(mask) + else: + scores.add_(mask) + return scores + + +class RegexLogitsProcessor(BaseLogitsProcessor): + + @classmethod + @cache() + def _get_guide(cls, regex_string: str, + tokenizer: PreTrainedTokenizerBase) -> Guide: + tokenizer = _adapt_tokenizer(tokenizer) + return RegexGuide.from_regex(regex_string, tokenizer) + + def __init__( + self, + regex_string: str, + tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[ReasoningParser], + ): + """Compile the FSM that drives the regex-structured generation. + + Parameters + ---------- + regex_string + A string that represents a regular expression + tokenizer + The model's tokenizer + + """ + super().__init__( + RegexLogitsProcessor._get_guide(regex_string, tokenizer), reasoner) + + +class JSONLogitsProcessor(RegexLogitsProcessor): + + def __init__(self, schema: Union[str, dict, BaseModel], + tokenizer: PreTrainedTokenizerBase, + whitespace_pattern: Union[str, None], + reasoner: Optional[ReasoningParser]): + """Compile the FSM that drives the JSON-guided generation. + + Parameters + ---------- + schema + A JSON schema that encodes the structure we want the model to + generate + tokenizer + The model's tokenizer + whitespace_pattern + Pattern to use for JSON syntactic whitespace (doesn't impact + string literals) + Example: allow only a single space or newline with + `whitespace_pattern=r"[\n ]?"` + """ + if isinstance(schema, type(BaseModel)): + schema_str = json.dumps(schema.model_json_schema()) + elif isinstance(schema, dict): + schema_str = json.dumps(schema) + elif isinstance(schema, str): + schema_str = schema + else: + raise ValueError( + f"Cannot parse schema {schema}. The schema must be either " + f"a Pydantic object, a dictionary or a string that contains " + f"the JSON Schema specification") + regex_string = build_regex_from_schema(schema_str, whitespace_pattern) + super().__init__(regex_string, tokenizer, reasoner) + + +class CFGLogitsProcessor(BaseLogitsProcessor): + + @classmethod + @cache() + def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide: + tokenizer = _adapt_tokenizer(tokenizer) + return CFGGuide(cfg, tokenizer) + + def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[ReasoningParser]): + """Compile the FSM that drives the context free grammar generation. + + Parameters + ---------- + cfg + A string that represents a context-free grammar + tokenizer + The model's tokenizer + + """ + super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer), + reasoner) + self._guide = self._guide.copy() + + def clone(self) -> "CFGLogitsProcessor": + cloned = copy.copy(self) + cloned._fsm_state = copy.deepcopy(self._fsm_state) + cloned._guide = self._guide.copy() + return cloned + + +@lru_cache(maxsize=32) +def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): + """Adapt vLLM's tokenizer to use to compile the FSM. + + The API of Outlines tokenizers is slightly different to that of + `transformers`. The decoder of outlines, returns a list whereas + the decode of vLLM returns an str. To sync the vLLM decoder with + outlines internal api, the decoder should be adapted. In addition + we need to handle the missing spaces to Llama's tokenizer to be + able to compile FSMs for this model. + + """ + if getattr(tokenizer, "_outlines_adapted", False): + return tokenizer + + tokenizer = copy.deepcopy(tokenizer) + + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if (type(token) is str and token.startswith(SPIECE_UNDERLINE) + or token == "<0x20>"): + return " " + string + + return string + + def change_decoder( + decoder: Callable[[list[int]], + str]) -> Callable[[list[int]], list[str]]: + """Sync vLLM's decoder with the outlines by returning list.""" + + def new_decoder(inp_tokens: list[int]) -> list[str]: + if (isinstance(inp_tokens, list) and len(inp_tokens) == 1 + and isinstance(inp_tokens[0], list)): + inp_tokens = inp_tokens[0] + return [decoder(inp_tokens)] + + return new_decoder + + tokenizer.convert_token_to_string = convert_token_to_string + tokenizer.decode = change_decoder(tokenizer.decode) + setattr(tokenizer, "_outlines_adapted", True) # noqa: B010 + + return tokenizer diff --git a/vllm/model_executor/guided_decoding/utils.py b/vllm/model_executor/guided_decoding/utils.py new file mode 100644 index 0000000..8fdfa98 --- /dev/null +++ b/vllm/model_executor/guided_decoding/utils.py @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import regex as re + + +def has_xgrammar_unsupported_json_features(schema: dict) -> bool: + """Check if JSON schema contains features unsupported by xgrammar.""" + + def check_object(obj: dict) -> bool: + if not isinstance(obj, dict): + return False + + # Check for numeric ranges + if obj.get("type") in ("integer", "number") and ("multipleOf" in obj): + return True + + # Check for array unsupported keywords + if obj.get("type") == "array" and any(key in obj for key in [ + "uniqueItems", "contains", "minContains", "maxContains", + "minItems", "maxItems" + ]): + return True + + # Unsupported keywords for strings + if obj.get("type") == "string" and any( + key in obj for key in ["minLength", "maxLength", "format"]): + return True + + # Unsupported keywords for objects + if obj.get("type") == "object" and any(key in obj for key in [ + "minProperties", "maxProperties", "propertyNames", + "patternProperties" + ]): + return True + + # Recursively check all nested objects and arrays + for value in obj.values(): + if isinstance(value, dict): + if check_object(value): + return True + elif isinstance(value, list): + for item in value: + if isinstance(item, dict) and check_object(item): + return True + + return False + + return check_object(schema) + + +def has_lmf_unsupported_json_features(schema: dict) -> bool: + """ + Check if JSON schema contains features unsupported + by lm_format_enforcer. + + Known issues: + - Regex patterns: + "grade": { + "type": "string", + "pattern": "^[A-D]$" # Regex pattern + }, + """ + + def check_object(obj: dict) -> bool: + if not isinstance(obj, dict): + return False + + # Check for pattern restrictions + if "pattern" in obj: + return True + + # Recursively check all nested objects and arrays + for value in obj.values(): + if isinstance(value, dict): + if check_object(value): + return True + elif isinstance(value, list): + for item in value: + if isinstance(item, dict) and check_object(item): + return True + + return False + + return check_object(schema) + + +def grammar_is_likely_lark(grammar_str: str) -> bool: + """ + Check if grammar appears to use Lark syntax. + + Args: + grammar_str: Input grammar string + + Returns: + bool: True if grammar appears to be in Lark format, False otherwise + + Examples: + >>> grammar_is_likely_lark("rule: 'abc'") + True + >>> grammar_is_likely_lark("rule ::= 'abc'") + False + """ + if not grammar_str or not isinstance(grammar_str, str): + return False + + for line in grammar_str.split('\n'): + # Remove both comment styles + line = re.sub(r'(#|//).*$', '', line).strip() + if not line: + continue + + # Look for GBNF rule definition + if '::=' in line: + return False + + return True + + +def convert_lark_to_gbnf(grammar_str: str) -> str: + """ + Convert a Lark grammar string to GBNF format. + + GBNF reference: + https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md + Lark grammar reference: + https://lark-parser.readthedocs.io/en/latest/grammar.html + + Args: + grammar_str: Input grammar in Lark format + + Returns: + str: Converted grammar in GBNF format + + Examples: + >>> print(convert_lark_to_gbnf("rule: 'hello'")) + root ::= rule + rule ::= "hello" + """ + if not isinstance(grammar_str, str): + raise ValueError(f"Grammar must be a string, got {type(grammar_str)}") + if not grammar_str.strip(): + raise ValueError("Grammar string cannot be empty") + + defined_rules = set() + referenced_rules = set() + output_lines = [] + + def clean_line(line: str) -> str: + """Remove comments and whitespace from line.""" + return re.sub(r'(#|//).*$', '', line).strip() + + def check_quotes(text: str, rule_name: str, line_num: int) -> None: + """Validate quote matching in text.""" + if text.count("'") % 2 != 0 or text.count('"') % 2 != 0: + raise ValueError( + f"Mismatched quotes in {rule_name} on line {line_num}") + + def extract_references(text: str) -> set: + """Extract rule references from text.""" + # Remove quoted strings and special characters + text = re.sub(r'"[^"]*"', '', text) + text = re.sub(r'[+*?()|\[\]{}]', ' ', text) + return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text)) + + # First pass: Find root rule and validate rule definitions + lines = [clean_line(line) for line in grammar_str.split('\n')] + first_rule = None + + for line_num, line in enumerate(lines, 1): + if not line or line.startswith('|'): + continue + + if ':' in line: + try: + name = line.split(':', 1)[0].strip().strip('?') + defined_rules.add(name) + if first_rule is None: + first_rule = name + if name == 'start': + first_rule = 'start' + except IndexError as e: + raise ValueError(f"Invalid rule format on line {line_num}. " + "Expected 'rule_name: definition'") from e + + if not defined_rules: + raise ValueError("No valid rules found in grammar") + + # Add root rule + output_lines.append(f"root ::= {first_rule}") + + # Second pass: Process rule definitions and alternatives + current_rule = None + current_definition = [] + + for line_num, line in enumerate(lines, 1): + if not line: + continue + + try: + if ':' in line and not line.startswith('|'): + # Save previous rule if exists + if current_rule: + output_lines.append( + f"{current_rule} ::= {' | '.join(current_definition)}") + + # Process new rule + name, definition = line.split(':', 1) + current_rule = name.strip().strip('?') + + check_quotes(definition, f"rule '{current_rule}'", line_num) + definition = re.sub(r"'([^']*)'", r'"\1"', definition) + referenced_rules.update(extract_references(definition)) + current_definition = [definition.strip()] + + elif line.startswith('|'): + if not current_rule: + raise ValueError(f"Alternative '|' on line {line_num} " + "without a preceding rule definition") + + alt_def = line[1:].strip() + check_quotes(alt_def, f"alternative for rule '{current_rule}'", + line_num) + alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def) + referenced_rules.update(extract_references(alt_def)) + current_definition.append(alt_def) + + except ValueError as e: + raise ValueError(f"Error on line {line_num}: {str(e)}") from e + + # Add final rule if exists + if current_rule: + output_lines.append( + f"{current_rule} ::= {' | '.join(current_definition)}") + + # Validate all rules are defined + undefined_rules = referenced_rules - defined_rules - {'root'} + if undefined_rules: + raise ValueError("Referenced rules are not defined: " + f"{', '.join(sorted(undefined_rules))}") + + return '\n'.join(output_lines) diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py new file mode 100644 index 0000000..bdd3a1a --- /dev/null +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -0,0 +1,426 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# noqa: UP007 +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import regex as re +import torch + +import vllm.envs +from vllm.logger import init_logger + +try: + import xgrammar as xgr + xgr_installed = True +except ImportError: + xgr_installed = False + pass + +from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf, + grammar_is_likely_lark) +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + from vllm.config import ModelConfig + from vllm.reasoning import ReasoningParser + from vllm.sampling_params import GuidedDecodingParams + +logger = init_logger(__name__) + + +def get_local_xgrammar_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizer, + model_config: ModelConfig, + reasoner: ReasoningParser | None, + max_threads: int = 8): + config = GrammarConfig.from_guided_params(guided_params=guided_params, + model_config=model_config, + tokenizer=tokenizer, + max_threads=max_threads) + return XGrammarLogitsProcessor(config, reasoner) + + +@dataclass(frozen=True) +class TokenizerData: + """Immutable container for cached tokenizer data.""" + metadata: str + encoded_vocab: list[str] = field(default_factory=list) + + +class TokenizerDataCache: + """Cache manager for tokenizer data to avoid repeated processing.""" + _cache: dict[int, TokenizerData] = {} + + @classmethod + def get_tokenizer_data( + cls, + tokenizer: PreTrainedTokenizer, + /, + *, + tokenizer_hash: int, + vocab_size: int, + ) -> TokenizerData: + + if tokenizer_hash not in cls._cache: + tokenizer_info = xgr.TokenizerInfo.from_huggingface( + tokenizer, + # NOTE: We will need to use lm_head's vocab_size + # to determine correct special_token_ids for this tokenizer. + # See https://github.com/mlc-ai/xgrammar/commit/70c959fb6d9cea75aae33c414763cd0602022d92 # noqa: E501 + vocab_size=vocab_size, + ) + metadata = json.loads(tokenizer_info.dump_metadata()) + + # Vendored from xgrammar logic to get encoded_vocab + # https://github.com/mlc-ai/xgrammar/blob/989222175c2a30fb7987d8bcce35bec1bf6817f2/python/xgrammar/tokenizer_info.py#L127 # noqa: E501 + try: + vocab_dict = tokenizer.get_vocab() + except AttributeError as e: + raise ValueError( + f"Cannot get the vocabulary of the tokenizer " + f"{type(tokenizer)}. The tokenizer should have a " + "get_vocab method.") from e + + # maintain tokenizer's indexing + encoded_vocab = [""] * tokenizer_info.vocab_size + for token, idx in vocab_dict.items(): + if idx < tokenizer_info.vocab_size: + encoded_vocab[idx] = token + + if isinstance(tokenizer, MistralTokenizer): + # REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 + metadata.update({ + "vocab_type": xgr.VocabType.BYTE_FALLBACK, + "add_prefix_space": True + }) + + cls._cache[tokenizer_hash] = TokenizerData( + encoded_vocab=encoded_vocab, + metadata=json.dumps(metadata), + ) + + return cls._cache[tokenizer_hash] + + +class GrammarCompilerCache: + """ + Cache for GrammarCompiler instances based on tokenizer. + + This cache reduces the overhead of creating new compiler instances when + using the same tokenizer configuration. + """ + _cache: dict[str, xgr.GrammarCompiler] = {} + + @classmethod + def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler: + cache_key = str(config.tokenizer_hash) + + if cache_key not in cls._cache: + config_data = config.tokenizer_data + + # In TokenizerDataCache.get_tokenizer_data, a serializable + # tokenizer_data is created and cached. This data is used to build + # a tokenizer_info and create an xgrammar compiler. + tokenizer_info = xgr.TokenizerInfo.from_vocab_and_metadata( + encoded_vocab=config_data.encoded_vocab, + metadata=config_data.metadata, + ) + cache_size = vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024 + cls._cache[cache_key] = xgr.GrammarCompiler( + tokenizer_info, + max_threads=config.max_threads, + cache_enabled=True, + cache_limit_bytes=cache_size, + ) + + return cls._cache[cache_key] + + +@dataclass +class GrammarConfig: + """Serializable configuration for grammar compilation""" + tokenizer_hash: int + tokenizer_data: TokenizerData + json_str: str | None = None + grammar_str: str | None = None + json_object: bool | None = None + any_whitespace: bool = True + regex_str: str | None = None + max_threads: int = 8 + + @classmethod + def from_guided_params(cls, + guided_params: GuidedDecodingParams, + model_config: ModelConfig, + tokenizer: PreTrainedTokenizer, + max_threads: int = 8) -> GrammarConfig: + + tokenizer_hash = hash(tokenizer) + tokenizer_data = TokenizerDataCache.get_tokenizer_data( + tokenizer, + tokenizer_hash=tokenizer_hash, + vocab_size=model_config.hf_text_config.vocab_size, + ) + + if guided_params.json: + if not isinstance(guided_params.json, str): + json_str = json.dumps(guided_params.json) + else: + json_str = guided_params.json + + any_whitespace = not guided_params.disable_any_whitespace + + # Check and log if model with xgrammar and whitespace have history + # of runaway generation of whitespaces. + # References: + # https://github.com/vllm-project/vllm/pull/12744 + # https://github.com/mlc-ai/xgrammar/issues/212 + model_with_warn = None + + if 'Mistral' in model_config.model: + model_with_warn = 'Mistral' + elif 'Qwen' in model_config.model: + model_with_warn = 'Qwen' + + if model_with_warn is not None and any_whitespace: + logger.info_once( + "%s model detected, consider setting `disable_any_whitespace` to prevent runaway generation of whitespaces.", # noqa: E501 + model_with_warn, + ) + # Validate the schema and raise ValueError here if it is invalid. + # This is to avoid exceptions in model execution, which will crash + # the engine worker process. + try: + xgr.Grammar.from_json_schema(json_str, + any_whitespace=any_whitespace) + except RuntimeError as err: + raise ValueError(str(err)) from err + + return cls(json_str=json_str, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads, + tokenizer_data=tokenizer_data, + any_whitespace=any_whitespace) + elif guided_params.grammar: + # XGrammar only supports GBNF grammars, so we must convert Lark + if grammar_is_likely_lark(guided_params.grammar): + try: + grammar_str = convert_lark_to_gbnf(guided_params.grammar) + except ValueError as e: + raise ValueError( + "Failed to convert the grammar from Lark to GBNF. " + "Please either use GBNF grammar directly or specify" + " --guided-decoding-backend=outlines.\n" + f"Conversion error: {str(e)}") from e + else: + grammar_str = guided_params.grammar + + # Validate the grammar and raise ValueError here if it is invalid. + # This is to avoid exceptions in model execution, which will crash + # the engine worker process. + try: + xgr.Grammar.from_ebnf(grammar_str) + except RuntimeError as err: + raise ValueError(str(err)) from err + + return cls(grammar_str=grammar_str, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads, + tokenizer_data=tokenizer_data) + elif guided_params.json_object: + return cls( + json_object=True, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads, + tokenizer_data=tokenizer_data, + ) + elif guided_params.choice: + choice_str = GrammarConfig.choice_as_grammar(guided_params.choice) + try: + xgr.Grammar.from_ebnf(choice_str) + except RuntimeError as err: + raise ValueError(str(err)) from err + + return cls( + grammar_str=choice_str, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads, + tokenizer_data=tokenizer_data, + ) + elif guided_params.regex: + return cls( + regex_str=guided_params.regex, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads, + tokenizer_data=tokenizer_data, + ) + else: + raise ValueError( + "Currently only support JSON and EBNF grammar mode for xgrammar" + ) + + @staticmethod + def escape_ebnf_string(s: str) -> str: + """Escape special characters in a EBNF string.""" + # Escape double quotes and backslashes + return re.sub(r'(["\\])', r'\\\1', s) + + @staticmethod + def choice_as_grammar(choice: list[str] | None) -> str: + if choice is None: + raise ValueError("Choice is not set") + escaped_choices = (GrammarConfig.escape_ebnf_string(c) for c in choice) + grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices)) + return grammar + + @staticmethod + def tokenizer_info(tokenizer_data: TokenizerData) -> xgr.TokenizerInfo: + return xgr.TokenizerInfo.from_vocab_and_metadata( + encoded_vocab=tokenizer_data.encoded_vocab, + metadata=tokenizer_data.metadata, + ) + + +@dataclass +class XGrammarLogitsProcessor: + """Wrapper class to support pickle protocol""" + config: GrammarConfig + reasoner: ReasoningParser | None = None + + ctx: xgr.CompiledGrammar | None = None + tokenizer_info: xgr.TokenizerInfo = None # type: ignore[assignment] + token_bitmask: torch.Tensor = None # type: ignore[assignment] + matchers: list[xgr.GrammarMatcher] = field(default_factory=list) + batch_size: int = field(default=1) + prefilled: bool = field(default=False) + + def __post_init__(self): + if self.tokenizer_info is None: + self.tokenizer_info = self.config.tokenizer_info( + self.config.tokenizer_data) + + def __getstate__(self) -> dict[str, Any]: + return {'config': self.config, 'reasoner': self.reasoner} + + def __setstate__(self, state: dict[str, Any]): + self.config = state['config'] + self.reasoner = state['reasoner'] + + self.tokenizer_info = GrammarConfig.tokenizer_info( + self.config.tokenizer_data) + self.ctx = None + self.matchers = [] + self.batch_size = 1 + self.token_bitmask = None # type: ignore[assignment] + self.prefilled = False + + def _ensure_ctx(self): + """Lazily initialize the processor in the worker process""" + if self.ctx is None: + compiler = GrammarCompilerCache.get_compiler(self.config) + if self.config.json_str is not None: + any_whitespace = self.config.any_whitespace + self.ctx = compiler\ + .compile_json_schema(self.config.json_str, + any_whitespace=any_whitespace) + elif self.config.grammar_str is not None: + self.ctx = compiler.compile_grammar(self.config.grammar_str) + elif self.config.json_object: + any_whitespace = self.config.any_whitespace + self.ctx = compiler\ + .compile_json_schema('{"type": "object"}', + any_whitespace=any_whitespace) + elif self.config.regex_str: + self.ctx = compiler.compile_regex(self.config.regex_str) + else: + raise ValueError( + "Invalid configuration for xgrammar logits processor") + + def __call__(self, input_ids: list[int], + scores: torch.Tensor) -> torch.Tensor: + + # Skip the structured logits processing if reasoning is not finished. + # reasoner is not None only when `--reasoning-parser` is set. + if self.reasoner is not None and \ + not self.reasoner.is_reasoning_end( + input_ids): + return scores + + if self.ctx is None: + self._ensure_ctx() + + if len(self.matchers) == 0: + self.matchers = [ + xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size) + ] + self.token_bitmask = xgr.allocate_token_bitmask( + self.batch_size, self.tokenizer_info.vocab_size) + + if not self.prefilled: + # Have not sampled a token yet + self.prefilled = True + else: + for i, matcher in enumerate(self.matchers): + if not matcher.is_terminated(): + sampled_token = input_ids[-1] + assert self.matchers[i].accept_token(sampled_token) + + for i, matcher in enumerate(self.matchers): + if not matcher.is_terminated(): + # @ubospica: ideally, fill_next_token_bitmask should be + # parallelized with model decoding + # See https://github.com/vllm-project/vllm/pull/10785/files#r1864278303 + matcher.fill_next_token_bitmask(self.token_bitmask, i) + + # token_bitmask is a CPU tensor for use with accept_token and + # fill_next_token_bitmask so we move it to the device of scores + device_type = scores.device.type + dtype = scores.dtype + if device_type != "cuda": + # xgrammar on cpu only supports float32 scores + # see: https://github.com/mlc-ai/xgrammar/blob/c1b64920cad24f44f235778c1c00bb52d57da01a/python/xgrammar/kernels/apply_token_bitmask_inplace_cpu.py#L22 + scores = scores.to("cpu").float().unsqueeze(0) + + # Note: In this method, if the tensors have different dimensions + # on CPU device fails, but on GPU it runs without error. Hence the + # unsqueeze above for scores, to match the token bitmask shape + xgr.apply_token_bitmask_inplace( + scores, self.token_bitmask.to(scores.device, non_blocking=True)) + if device_type != "cuda": + scores = scores.to(dtype).to(device_type).squeeze() + + return scores + + def clone(self) -> XGrammarLogitsProcessor: + """Create a new instance with shared compiled grammar + but separate state""" + new_processor = XGrammarLogitsProcessor(self.config, self.reasoner, + None, self.tokenizer_info) + + # Share the compiled grammar context (immutable after compilation) + new_processor.ctx = self.ctx + + # Create fresh matchers for the new sequence + if self.ctx is not None: + new_processor.matchers = [ + xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size) + ] + + # Create a new token bitmask with the same size + if hasattr(self, 'token_bitmask') and self.token_bitmask is not None: + new_processor.token_bitmask = self.token_bitmask + + # Copy simple attributes + new_processor.batch_size = self.batch_size + # Reset prefilled state for new sequence + new_processor.prefilled = False + + return new_processor diff --git a/vllm/model_executor/layers/__init__.py b/vllm/model_executor/layers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py new file mode 100644 index 0000000..93ad1f4 --- /dev/null +++ b/vllm/model_executor/layers/activation.py @@ -0,0 +1,430 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Custom activation functions.""" +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.utils import LazyDict +import vllm.envs as envs + + +@CustomOp.register("fatrelu_and_mul") +class FatreluAndMul(CustomOp): + """An activation function for FATReLU. + + The function computes x -> FATReLU(x[:d]) * x[d:] where + d = x.shape[-1] // 2. + This is used in openbmb/MiniCPM-S-1B-sft. + + Shapes: + x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) + return: (num_tokens, d) or (batch_size, seq_len, d) + """ + + def __init__(self, threshold: float = 0.): + super().__init__() + self.threshold = threshold + if current_platform.is_cuda_alike(): + self.op = torch.ops._C.fatrelu_and_mul + elif current_platform.is_cpu(): + self._forward_method = self.forward_native + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + x1 = x[..., :d] + x2 = x[..., d:] + x1 = F.threshold(x1, self.threshold, 0.0) + return x1 * x2 + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + self.op(out, x, self.threshold) + return out + + +@CustomOp.register("silu_and_mul") +class SiluAndMul(CustomOp): + """An activation function for SwiGLU. + + The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2. + + Shapes: + x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) + return: (num_tokens, d) or (batch_size, seq_len, d) + """ + + def __init__(self): + super().__init__() + if current_platform.is_cuda_alike() or current_platform.is_cpu(): + self.op = torch.ops._C.silu_and_mul + self.op_opt = torch.ops._C.silu_and_mul_opt + elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops + self.op = ipex_ops.silu_and_mul + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + d = x.shape[-1] // 2 + return F.silu(x[..., :d]) * x[..., d:] + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + if envs.VLLM_USE_OPT_OP: + self.op_opt(out, x) + else: + self.op(out, x) + return out + + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + self.op(out, x) + return out + + def forward_neuron(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + x_reshaped = x.view(-1, x.shape[-1]) + s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d]) + result = s * x_reshaped[:, d:] + return result.view(*x.shape[:-1], d) + + +@CustomOp.register("mul_and_silu") +class MulAndSilu(CustomOp): + """An activation function for SwiGLU. + + The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2. + + Shapes: + x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) + return: (num_tokens, d) or (batch_size, seq_len, d) + """ + + def __init__(self): + super().__init__() + if current_platform.is_cuda_alike(): + self.op = torch.ops._C.mul_and_silu + elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops + self.op = ipex_ops.silu_and_mul + elif current_platform.is_cpu(): + self._forward_method = self.forward_native + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + d = x.shape[-1] // 2 + return x[..., :d] * F.silu(x[..., d:]) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + self.op(out, x) + return out + + # TODO implement forward_xpu for MulAndSilu + # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + + +@CustomOp.register("gelu_and_mul_sparse") +class GeluAndMulSparse(CustomOp): + """An activation function for GeluAndMulSparse. + This activation function is used in Gemma3n. It computes: + up_proj = self.up_proj(x) + gate_proj = self.gate_proj(x) + gate_proj = self._gaussian_topk(gate_proj) # sparsity + activations = self.act_fn(gate_proj) # gelu + down_proj = self.down_proj(activations * up_proj) + Shapes: + x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) + return: (num_tokens, d) or (batch_size, seq_len, d) + """ + + def __init__(self, activation_sparsity: float, approximate: str = "none"): + super().__init__() + # Gelu. + self.approximate = approximate + if approximate not in ("none", "tanh"): + raise ValueError(f"Unknown approximate mode: {approximate}") + + # Sparsity. + if activation_sparsity == 0.0: + raise ValueError( + "activation_sparsity is 0.0. Please use GeluAndMul.") + target_sparsity_tensor = torch.tensor(activation_sparsity, + dtype=torch.float32) + normal_dist = torch.distributions.normal.Normal(0, 1) + self.std_multiplier = normal_dist.icdf(target_sparsity_tensor) + + def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor: + """Get % sparse percentile of the Gaussian distribution.""" + # NOTE(rob): for TP>1, we could all-gather to get the means/std. + # But we do not do this because in expectation they are the same + # and in practice the eval scores are good without gathering. + mean = torch.mean(x, dim=-1, keepdim=True) + std = torch.std(x, dim=-1, keepdim=True, unbiased=False) + cutoff_x = mean + std * self.std_multiplier + return nn.functional.relu(x - cutoff_x) + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + d = x.shape[-1] // 2 + out = self._gaussian_topk(x[..., :d]) + out = F.gelu(out, approximate=self.approximate) + return out * x[..., d:] + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + return self.forward_native(x) + + +@CustomOp.register("gelu_and_mul") +class GeluAndMul(CustomOp): + """An activation function for GeGLU. + + The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2. + + Shapes: + x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d) + return: (batch_size, seq_len, d) or (num_tokens, d) + """ + + def __init__(self, approximate: str = "none"): + super().__init__() + self.approximate = approximate + if approximate not in ("none", "tanh"): + raise ValueError(f"Unknown approximate mode: {approximate}") + if current_platform.is_cuda_alike() or current_platform.is_cpu(): + if approximate == "none": + self.op = torch.ops._C.gelu_and_mul + self.op_opt = torch.ops._C.gelu_and_mul_opt + elif approximate == "tanh": + self.op = torch.ops._C.gelu_tanh_and_mul + self.op_opt = torch.ops._C.gelu_tanh_and_mul_opt + elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops + if approximate == "none": + self.op = ipex_ops.gelu_and_mul + else: + self.op = ipex_ops.gelu_tanh_and_mul + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + d = x.shape[-1] // 2 + return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:] + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + if envs.VLLM_USE_OPT_OP: + self.op_opt(out, x) + else: + self.op(out, x) + return out + + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + self.op(out, x) + return out + + def extra_repr(self) -> str: + return f'approximate={repr(self.approximate)}' + + +@CustomOp.register("gelu_new") +class NewGELU(CustomOp): + + def __init__(self): + super().__init__() + if current_platform.is_cuda_alike() or current_platform.is_cpu(): + self.op = torch.ops._C.gelu_new + elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops + self.op = ipex_ops.gelu_new + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + c = math.sqrt(2.0 / math.pi) + return 0.5 * x * (1.0 + torch.tanh(c * + (x + 0.044715 * torch.pow(x, 3.0)))) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + self.op(out, x) + return out + + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + return self.op(x) + + +@CustomOp.register("gelu_fast") +class FastGELU(CustomOp): + + def __init__(self): + super().__init__() + if current_platform.is_cuda_alike() or current_platform.is_cpu(): + self.op = torch.ops._C.gelu_fast + elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops + self.op = ipex_ops.gelu_fast + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * + (1.0 + 0.044715 * x * x))) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + self.op(out, x) + return out + + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + return self.op(x) + + +@CustomOp.register("quick_gelu") +class QuickGELU(CustomOp): + # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90 + def __init__(self): + super().__init__() + if current_platform.is_cuda_alike() or current_platform.is_cpu(): + self.op = torch.ops._C.gelu_quick + elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops + self.op = ipex_ops.gelu_quick + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + return x * torch.sigmoid(1.702 * x) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + self.op(out, x) + return out + + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + self.op(out, x) + return out + + # TODO implement forward_xpu for QuickGELU + # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + + +@CustomOp.register("relu2") +class ReLUSquaredActivation(CustomOp): + """ + Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2 + """ + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + return torch.square(F.relu(x)) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + return self.forward_native(x) + + +class ScaledActivation(nn.Module): + """An activation function with post-scale parameters. + + This is used for some quantization methods like AWQ. + """ + + def __init__( + self, + act_module: nn.Module, + intermediate_size: int, + input_is_parallel: bool = True, + params_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.act = act_module + self.input_is_parallel = input_is_parallel + if input_is_parallel: + tp_size = get_tensor_model_parallel_world_size() + intermediate_size_per_partition = divide(intermediate_size, + tp_size) + else: + intermediate_size_per_partition = intermediate_size + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.scales = nn.Parameter( + torch.empty(intermediate_size_per_partition, dtype=params_dtype)) + set_weight_attrs(self.scales, {"weight_loader": self.weight_loader}) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.act(x) / self.scales + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + param_data = param.data + if self.input_is_parallel: + tp_rank = get_tensor_model_parallel_rank() + shard_size = param_data.shape[0] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(0, start_idx, shard_size) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +_ACTIVATION_REGISTRY = LazyDict({ + "gelu": + lambda: nn.GELU(), + "gelu_fast": + lambda: FastGELU(), + "gelu_new": + lambda: NewGELU(), + "gelu_pytorch_tanh": + lambda: nn.GELU(approximate="tanh"), + "relu": + lambda: nn.ReLU(), + "relu2": + lambda: ReLUSquaredActivation(), + "silu": + lambda: nn.SiLU(), + "quick_gelu": + lambda: QuickGELU(), +}) + + +def get_act_fn(act_fn_name: str) -> nn.Module: + """Get an activation function by name.""" + act_fn_name = act_fn_name.lower() + if act_fn_name not in _ACTIVATION_REGISTRY: + raise ValueError( + f"Activation function {act_fn_name!r} is not supported.") + + return _ACTIVATION_REGISTRY[act_fn_name] + + +_ACTIVATION_AND_MUL_REGISTRY = LazyDict({ + "gelu": lambda: GeluAndMul(), + "silu": lambda: SiluAndMul(), + "geglu": lambda: GeluAndMul(), +}) + + +def get_act_and_mul_fn(act_fn_name: str) -> nn.Module: + """Get an activation-and-mul (i.e. SiluAndMul) function by name.""" + act_fn_name = act_fn_name.lower() + if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY: + raise ValueError( + f"Activation function {act_fn_name!r} is not supported.") + + return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name] diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py new file mode 100644 index 0000000..3d40879 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from contextlib import contextmanager +from typing import Any, Optional + +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute, + FusedMoEPrepareAndFinalize) +from vllm.triton_utils import HAS_TRITON + +_config: Optional[dict[str, Any]] = None + + +@contextmanager +def override_config(config): + global _config + old_config = _config + _config = config + yield + _config = old_config + + +def get_config() -> Optional[dict[str, Any]]: + return _config + + +__all__ = [ + "FusedMoE", + "FusedMoEConfig", + "FusedMoEMethodBase", + "FusedMoeWeightScaleSupported", + "FusedMoEPermuteExpertsUnpermute", + "FusedMoEActivationFormat", + "FusedMoEPrepareAndFinalize", + "override_config", + "get_config", +] + +if HAS_TRITON: + # import to register the custom ops + import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa + import vllm.model_executor.layers.fused_moe.fused_moe # noqa + from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + BatchedDeepGemmExperts) + from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 + BatchedTritonOrDeepGemmExperts) + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + CutlassExpertsFp8, cutlass_moe_fp4, cutlass_moe_fp8) + from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + DeepGemmExperts) + from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts) + from vllm.model_executor.layers.fused_moe.fused_moe import ( + TritonExperts, fused_experts, fused_moe, fused_topk, + get_config_file_name, grouped_topk) + from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts) + + __all__ += [ + "fused_moe", + "fused_topk", + "fused_experts", + "get_config_file_name", + "grouped_topk", + "cutlass_moe_fp8", + "cutlass_moe_fp4", + "CutlassExpertsFp8", + "TritonExperts", + "BatchedTritonExperts", + "DeepGemmExperts", + "BatchedDeepGemmExperts", + "TritonOrDeepGemmExperts", + "BatchedTritonOrDeepGemmExperts", + ] diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py new file mode 100644 index 0000000..a8788e3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -0,0 +1,298 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.triton_utils import tl, triton + +logger = init_logger(__name__) + + +@triton.jit +def _silu_mul_fp8_quant_deep_gemm( + # Pointers ------------------------------------------------------------ + input_ptr, # 16-bit activations (E, T, 2*H) + y_q_ptr, # fp8 quantized activations (E, T, H) + y_s_ptr, # 16-bit scales (E, T, G) + counts_ptr, # int32 num tokens per expert (E) + + # Sizes --------------------------------------------------------------- + H: tl.constexpr, # hidden dimension (per output) + GROUP_SIZE: tl.constexpr, # elements per group (usually 128) + + # Strides for input (elements) --------------------------------------- + stride_i_e, + stride_i_t, + stride_i_h, + + # Strides for y_q (elements) ----------------------------------------- + stride_yq_e, + stride_yq_t, + stride_yq_h, + + # Strides for y_s (elements) ----------------------------------------- + stride_ys_e, + stride_ys_t, + stride_ys_g, + + # Stride for counts (elements) + stride_counts_e, + + # Numeric params ------------------------------------------------------ + eps: tl.constexpr, + fp8_min: tl.constexpr, + fp8_max: tl.constexpr, + + # Meta --------------------------------------------------------------- + BLOCK: tl.constexpr, +): + G = H // GROUP_SIZE + + # map program id -> (e, g) + pid = tl.program_id(0) + e = pid // G + g = pid % G + + e = e.to(tl.int64) + g = g.to(tl.int64) + + # number of valid tokens for this expert + n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64) + + cols = tl.arange(0, BLOCK) + cols = cols.to(tl.int64) + mask_h = cols < BLOCK + + t = tl.zeros([], tl.int64) + while t < n_tokens: + base_i_offset = (e * stride_i_e + t * stride_i_t + + g * GROUP_SIZE * stride_i_h) + base_yq_offset = (e * stride_yq_e + t * stride_yq_t + + g * GROUP_SIZE * stride_yq_h) + base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g + + mask = mask_h + x = tl.load(input_ptr + base_i_offset + cols * stride_i_h, + mask=mask, + other=0.0).to(tl.float32) + y2 = tl.load(input_ptr + base_i_offset + H * stride_i_h + + cols * stride_i_h, + mask=mask, + other=0.0).to(tl.float32) + + x = x * (1.0 / (1.0 + tl.exp(-x))) + y = x * y2 + + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask) + tl.store(y_s_ptr + base_ys_offset, y_s) + + t += 1 + + +def silu_mul_fp8_quant_deep_gemm( + y: torch.Tensor, # (E, T, 2*H) float32 + tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert + group_size: int = 128, + eps: float = 1e-10, +): + """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales + + y has shape (E, T, 2*H). The first half of the last dimension is + silu-activated, multiplied by the second half, then quantized into FP8. + + Returns `(y_q, y_s)` where + * `y_q` is the FP8 tensor of shape `(E, T, H)`, same layout as `y[..., :H]`. + * `y_s` has shape `(E, T, H // group_size)` and strides `(T*G, 1, T)` + """ + assert y.ndim == 3, "y must be (E, T, 2*H)" + E, T, H2 = y.shape + assert H2 % 2 == 0, "last dim of y must be even (2*H)" + H = H2 // 2 + G = H // group_size + assert H % group_size == 0, "H must be divisible by group_size" + assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, \ + "tokens_per_expert must be shape (E,)" + tokens_per_expert = tokens_per_expert.to(device=y.device, + dtype=torch.int32) + + # allocate outputs + fp8_dtype = torch.float8_e4m3fn + y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) + + # strides (elements) + stride_i_e, stride_i_t, stride_i_h = y.stride() + stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() + + # desired scale strides (elements): (T*G, 1, T) + stride_ys_e = T * G + stride_ys_t = 1 + stride_ys_g = T + y_s = torch.empty_strided((E, T, G), + (stride_ys_e, stride_ys_t, stride_ys_g), + dtype=torch.float32, + device=y.device) + + stride_cnt_e = tokens_per_expert.stride()[0] + + # static grid over experts and H-groups. + # A loop inside the kernel handles the token dim + grid = (E * G, ) + + f_info = torch.finfo(fp8_dtype) + fp8_max = f_info.max + fp8_min = f_info.min + + _silu_mul_fp8_quant_deep_gemm[grid]( + y, + y_q, + y_s, + tokens_per_expert, + H, + group_size, + stride_i_e, + stride_i_t, + stride_i_h, + stride_yq_e, + stride_yq_t, + stride_yq_h, + stride_ys_e, + stride_ys_t, + stride_ys_g, + stride_cnt_e, + eps, + fp8_min, + fp8_max, + BLOCK=group_size, + num_warps=4, + ) + + return y_q, y_s + + +class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): + + # The Deep Gemm kernels only support block size of 128 + DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128] + + def __init__(self, + max_num_tokens: int, + num_dispatchers: int, + block_shape: list[int], + per_act_token_quant=False): + """ + max_num_tokens: Maximum number of tokens from a DP Rank + num_dispatchers: The number of DP dispatchers. + block_shape: Block quantization block shape. + per_act_token_quant: Per activation token quantization flag. + """ + super().__init__( + FusedMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + )) + assert self.block_shape == self.DEEPGEMM_BLOCK_SHAPE + self.max_num_tokens = max_num_tokens + self.num_dispatchers = num_dispatchers + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) + + def supports_chunking(self) -> bool: + return False + + def supports_expert_map(self) -> bool: + return False + + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + assert a.dim() == 2 + # FIXME (varun): We should be able to dispatch only from the leader + # DP ranks in the case of TP > 1. At the moment, all the Ranks + # end up sending their tokens. This needs to be fixed. + num_dispatchers = self.num_dispatchers + num_experts = local_num_experts + max_num_tokens = a.size( + 0) if self.max_num_tokens is None else self.max_num_tokens + workspace13 = (num_experts, max_num_tokens * num_dispatchers, + max(K, N)) + workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2)) + output = (num_experts, max_num_tokens * num_dispatchers, K) + return (workspace13, workspace2, output, a.dtype) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ): + import deep_gemm as dg + assert hidden_states.ndim == 3 + assert self.block_shape is not None + + a1q = hidden_states + _, N, K = w1.size() + + assert w2.size(1) == K + + E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size( + hidden_states, w1, w2, topk_ids) + + workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N)) + + # (from deepgemm docs) : A value hint (which is a value on CPU) + # for the M expectation of each batch, correctly setting this value + # may lead to better performance. + expected_m = max_num_tokens + + dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a1q, a1q_scale), + (w1, w1_scale), + out=workspace1, + masked_m=expert_num_tokens, + expected_m=expected_m) + + assert expert_num_tokens is not None + a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1, + expert_num_tokens) + + dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale), + (w2, w2_scale), + out=output, + masked_m=expert_num_tokens, + expected_m=expected_m) diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py new file mode 100644 index 0000000..0d67b4a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + BatchedDeepGemmExperts) +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts) + + +class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__(self, + max_num_tokens: int, + num_dispatchers: int, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + block_shape: Optional[list[int]] = None, + per_act_token_quant: bool = False, + allow_deep_gemm: bool = False): + assert not use_int8_w8a8, "NYI" + assert not use_int8_w8a16, "NYI" + assert not use_int4_w4a16, "NYI" + + super().__init__( + FusedMoEQuantConfig.make( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + block_shape=block_shape, + per_act_token_quant=per_act_token_quant, + )) + self.allow_deep_gemm = allow_deep_gemm + + self.batched_triton_experts = BatchedTritonExperts( + max_num_tokens=max_num_tokens, + num_dispatchers=num_dispatchers, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_act_token_quant=self.per_act_token_quant, + block_shape=self.block_shape, + ) + + self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 + and self.block_shape + == BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE) + + self.batched_deep_gemm_experts = BatchedDeepGemmExperts( + max_num_tokens=max_num_tokens, + num_dispatchers=num_dispatchers, + block_shape=self.block_shape, # type: ignore[arg-type] + ) if self.allow_deep_gemm else None + + assert (self.batched_deep_gemm_experts is not None + or self.batched_triton_experts is not None) + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + if self.batched_triton_experts is not None: + assert (self.batched_deep_gemm_experts is None + or self.batched_deep_gemm_experts.activation_formats + == self.batched_triton_experts.activation_formats) + return self.batched_triton_experts.activation_formats + else: + assert self.batched_deep_gemm_experts is not None + return self.batched_deep_gemm_experts.activation_formats + + def supports_chunking(self) -> bool: + bdge = self.batched_deep_gemm_experts + bte = self.batched_triton_experts + return ((bdge is None or bdge.supports_chunking()) + and (bte is None or bte.supports_chunking())) + + def supports_expert_map(self) -> bool: + bdge = self.batched_deep_gemm_experts + bte = self.batched_triton_experts + return ((bdge is None or bdge.supports_expert_map()) + and (bte is None or bte.supports_expert_map())) + + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + # Note: the deep gemm workspaces are strictly larger than the triton + # workspaces so we can be pessimistic here and allocate for DeepGemm + # even if we fall back to triton later, e.g. if expert maps are set. + if self.allow_deep_gemm: + assert self.batched_deep_gemm_experts is not None + return self.batched_deep_gemm_experts.workspace_shapes( + a, aq, M, N, K, topk, global_num_experts, local_num_experts) + else: + assert self.batched_triton_experts is not None + return self.batched_triton_experts.workspace_shapes( + a, aq, M, N, K, topk, global_num_experts, local_num_experts) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ): + experts = (self.batched_deep_gemm_experts + if self.allow_deep_gemm else self.batched_triton_experts) + assert experts is not None + experts.apply(output, hidden_states, w1, w2, topk_ids, activation, + global_num_experts, expert_map, w1_scale, w2_scale, + w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, + workspace2, expert_num_tokens) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py new file mode 100644 index 0000000..993f07f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -0,0 +1,460 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from compressed_tensors.quantization import (QuantizationArgs, + QuantizationStrategy, + QuantizationType) + +import vllm.envs as envs +from vllm.config import ParallelConfig +from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.utils import cdiv + +logger = init_logger(__name__) + + +def _get_quant_config_quantization_args( + quant_config: Optional[QuantizationConfig], + prop_name: str, +) -> Optional[QuantizationArgs]: + if (quant_config is not None and hasattr(quant_config, 'target_scheme_map') + and "Linear" in quant_config.target_scheme_map and + "input_activations" in quant_config.target_scheme_map["Linear"]): + return quant_config.target_scheme_map["Linear"].get(prop_name) + else: + return None + + +def get_quant_config_input_quant( + quant_config: Optional[QuantizationConfig] +) -> Optional[QuantizationArgs]: + return _get_quant_config_quantization_args(quant_config, + "input_activations") + + +def get_quant_config_weight_quant( + quant_config: Optional[QuantizationConfig] +) -> Optional[QuantizationArgs]: + return _get_quant_config_quantization_args(quant_config, "weights") + + +# TODO (bnell): use scalar_type instead of bools? +def get_config_quant_dtype( + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + use_int4_w4a8: bool, +) -> Optional[torch.dtype]: + if use_fp8_w8a8: + return torch.float8_e4m3fn + elif use_int8_w8a8: + return torch.int8 + return None + + +@dataclass +class FusedMoEQuantConfig: + # The post quantization activation type. + quant_dtype: Optional[torch.dtype] = None + per_act_token_quant: bool = False + per_out_ch_quant: bool = False + block_shape: Optional[list[int]] = None + + # TODO: add col major flag? + # add detailed quant info for input, intermediates, weights, etc? + + def __post_init__(self): + assert (not self.per_act_token_quant + or self.block_shape is None), "illegal quantization" + + @property + def is_quantized(self) -> bool: + return self.quant_dtype is not None + + @property + def is_per_act_token(self) -> bool: + return self.per_act_token_quant + + @property + def is_block_quantized(self) -> bool: + return self.block_shape is not None + + @property + def is_per_tensor(self) -> bool: + return not self.per_act_token_quant and self.block_shape is None + + def scale_shape( + self, + max_tokens: int, + hidden_dim: int, + ) -> Optional[tuple[int, int]]: + if self.is_quantized: + if self.is_block_quantized: + assert self.block_shape is not None + _, block_k = self.block_shape + k_tiles = cdiv(hidden_dim, block_k) + return (max_tokens, k_tiles) + elif self.is_per_act_token: + return (max_tokens, 1) + else: + return (1, 1) + else: + return None + + def batched_scale_shape( + self, + num_experts: int, + max_tokens: int, + hidden_dim: int, + ) -> Optional[tuple[int, int, int]]: + if self.is_quantized: + scale_shape = self.scale_shape(max_tokens, hidden_dim) + assert scale_shape is not None + return (num_experts, *scale_shape) + else: + return None + + @staticmethod + def make( + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_int4_w4a8: bool = False, + per_act_token_quant: bool = False, + per_out_ch_quant: bool = False, + block_shape: Optional[list[int]] = None, + ) -> "FusedMoEQuantConfig": + assert sum([ + int(flag) for flag in [ + use_fp8_w8a8, + use_int8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + use_int4_w4a8, + ] + ]) <= 1, "Quantization flags are mutually exclusive." + + quant_dtype = get_config_quant_dtype( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + use_int4_w4a8=use_int4_w4a8, + ) + return FusedMoEQuantConfig( + quant_dtype, + per_act_token_quant, + per_out_ch_quant, + block_shape, + ) + + +@dataclass +class FusedMoEParallelConfig: + tp_size: int + dp_size: int + ep_size: int + tp_rank: int + dp_rank: int + ep_rank: int + + use_ep: bool # whether to use EP or not + + @property + def use_all2all_kernels(self): + return self.dp_size > 1 and self.use_ep + + @property + def use_pplx_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "pplx") + + @property + def use_deepep_ht_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput") + + @property + def use_deepep_ll_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") + + @staticmethod + def make(tp_size_: int, dp_size_: int, + vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": + """ + Determine MoE parallel configuration. Based on the input tp_size_, + dp_size_, ep_size_ and vllm's parallel config, determine what + level's of parallelism to use in the fused moe layer. + + Args: + tp_size_ (int): tp_size passed into the FusedMoE constructor. + dp_size_ (int): dp_size passed into the FusedMoE constructor. + ep_size_ (int): ep_size passed into the FusedMoE constructor. + vllm_parallel_config (ParallelConfig): vllm's parallel config + object. + + Examples: + When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1, + we simply return the sizes unaltered and the ranks set to 0. + + Expert Parallelism is considered only when either dp_size_ or tp_size_ + is non trivial. + + When TP = 2, DP = 1 and EP = False, the configuration on different + devices, + - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} // + legend : {size, rank} + - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0} + - Comment : Tensors are sharded across 2 devices. + + When TP = 1, DP = 2 and EP = False, the configuration on different + devices, + - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0} + - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded + across 2 decvices. + + When TP = 2, DP = 2 and EP = False, the configuration on different + devices, + - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0} + - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0} + - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0} + - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded + across 4 devices. + + When, TP = 2, DP = 1 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1} + - Comment: The experts are split between the 2 devices. + + When, TP = 1, DP = 2 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1} + - Comment: There are 2 engine instances and the experts are split + between the 2 devices. + + When TP = 2, DP = 2 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0} + - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1} + - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2} + - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3} + - Comment: There are 2 engine instances and the experts are split + between the 4 devices. + """ + + def flatten_tp_across_dp(dp_rank: int): + tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank() + # There are actually dp_size_ * tp_size_ devices. Update tp_size + # and tp_rank so we shard across all devices. + tp_size = dp_size_ * tp_size_ + tp_rank = dp_rank * tp_size_ + tp_rank + return tp_size, tp_rank + + use_ep = (dp_size_ * tp_size_ > 1 + and vllm_parallel_config.enable_expert_parallel) + + dp_size = dp_size_ + dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 + tp_size, tp_rank = flatten_tp_across_dp(dp_rank) + + if not use_ep: + return FusedMoEParallelConfig(tp_size=tp_size, + tp_rank=tp_rank, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=1, + ep_rank=0, + use_ep=False) + # DP + EP / TP + EP / DP + TP + EP + assert use_ep + # In EP, each device owns a set of experts fully. There is no tensor + # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that. + ep_size = tp_size + ep_rank = tp_rank + return FusedMoEParallelConfig(tp_size=1, + tp_rank=0, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=ep_size, + ep_rank=ep_rank, + use_ep=True) + + +# Adapted from pplx-kernels tests/all_to_all_utils.py +@dataclass +class FusedMoEConfig: + num_experts: int + experts_per_token: int + hidden_dim: int + + num_local_experts: int + moe_parallel_config: FusedMoEParallelConfig + + # The activation type. + in_dtype: torch.dtype + + quant_config: Optional[FusedMoEQuantConfig] = None + + max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE + + def __post_init__(self): + if self.dp_size > 1: + logger.debug("Using FusedMoEConfig::max_num_tokens=%d", + self.max_num_tokens) + + assert self.max_num_tokens > 0 + + @property + def quant_dtype(self) -> Optional[torch.dtype]: + if self.quant_config is not None: + return self.quant_config.quant_dtype + else: + return None + + @property + def block_shape(self) -> Optional[list[int]]: + if self.quant_config is not None: + return self.quant_config.block_shape + else: + return None + + @property + def per_act_token_quant(self) -> bool: + if self.quant_config is not None: + return self.quant_config.per_act_token_quant + else: + return False + + @property + def per_out_ch_quant(self) -> bool: + if self.quant_config is not None: + return self.quant_config.per_out_ch_quant + else: + return False + + @property + def tp_size(self): + return self.moe_parallel_config.tp_size + + @property + def dp_size(self): + return self.moe_parallel_config.dp_size + + @property + def ep_size(self): + return self.moe_parallel_config.ep_size + + @property + def tp_rank(self): + return self.moe_parallel_config.tp_rank + + @property + def dp_rank(self): + return self.moe_parallel_config.dp_rank + + @property + def ep_rank(self): + return self.moe_parallel_config.ep_rank + + @property + def use_ep(self): + return self.moe_parallel_config.use_ep + + @property + def use_pplx_kernels(self): + return self.moe_parallel_config.use_pplx_kernels + + @property + def use_deepep_ht_kernels(self): + return self.moe_parallel_config.use_deepep_ht_kernels + + @property + def use_deepep_ll_kernels(self): + return self.moe_parallel_config.use_deepep_ll_kernels + + @staticmethod + def make( + num_experts: int, + experts_per_token: int, + hidden_dim: int, + num_local_experts: int, + moe_parallel_config: FusedMoEParallelConfig, + in_dtype: torch.dtype, + max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE, + quant_config: Optional[Union[FusedMoEQuantConfig, + QuantizationConfig]] = None + ) -> "FusedMoEConfig": + + _quant_config: Optional[FusedMoEQuantConfig] = None + + if quant_config is not None and isinstance(quant_config, + QuantizationConfig): + if hasattr(quant_config, 'weight_block_size'): + block_shape = quant_config.weight_block_size + else: + block_shape = None + per_act_token_quant = False + per_out_ch_quant = False + quant_dtype: Optional[torch.dtype] = None + + input_quant = get_quant_config_input_quant(quant_config) + weight_quant = get_quant_config_weight_quant(quant_config) + + if input_quant is not None: + per_act_token_quant = (input_quant.strategy + == QuantizationStrategy.TOKEN + if input_quant is not None else False) + + if input_quant.num_bits == 8: + if input_quant.type == QuantizationType.FLOAT: + quant_dtype = torch.float8_e4m3fn + elif input_quant.type == QuantizationType.INT: + quant_dtype = torch.int8 + + from vllm.model_executor.layers.quantization.fp8 import Fp8Config + if quant_dtype is None and isinstance(quant_config, Fp8Config): + quant_dtype = torch.float8_e4m3fn + + if weight_quant is not None: + per_out_ch_quant = ( + weight_quant.strategy == QuantizationStrategy.CHANNEL) + + if quant_dtype is not None: + _quant_config = FusedMoEQuantConfig( + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_out_ch_quant, + block_shape=block_shape, + ) + else: + _quant_config = FusedMoEQuantConfig() + logger.warning_once("MoE DP setup unable to determine " + "quantization scheme or unsupported " + "quantization type. This model will " + "not run with DP enabled.") + else: + _quant_config = quant_config + + return FusedMoEConfig( + num_experts=num_experts, + experts_per_token=experts_per_token, + hidden_dim=hidden_dim, + num_local_experts=num_local_experts, + moe_parallel_config=moe_parallel_config, + in_dtype=in_dtype, + quant_config=_quant_config, + max_num_tokens=max_num_tokens, + ) diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json new file mode 100644 index 0000000..56c1a4e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..d3677be --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json new file mode 100644 index 0000000..265768f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "5120": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "9216": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "13312": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "17408": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "25600": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "33792": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "41984": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "50176": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "58368": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..d3be23d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "5120": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "9216": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "13312": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "17408": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "25600": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "33792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "41984": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "50176": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "58368": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json new file mode 100644 index 0000000..589f5d3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "5120": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "9216": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "13312": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "17408": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "25600": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "33792": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "41984": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "50176": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "58368": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json new file mode 100644 index 0000000..2c78bfa --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "5120": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "9216": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "13312": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "17408": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "25600": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "33792": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "41984": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "50176": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "58368": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..4da841e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "5120": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "9216": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "13312": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "17408": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "25600": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "33792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "41984": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "50176": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "58368": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json new file mode 100644 index 0000000..2003567 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "5120": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "9216": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "13312": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "17408": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "25600": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "33792": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "41984": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "50176": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "58368": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..e076615 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "5120": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "9216": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "13312": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "17408": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "25600": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "33792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "41984": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "50176": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "58368": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json new file mode 100644 index 0000000..ee89655 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "5120": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "9216": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "13312": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "17408": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "25600": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "33792": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "41984": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "50176": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "58368": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..05aed8b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "5120": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "9216": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "13312": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "17408": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "25600": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "33792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "41984": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "50176": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "58368": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..555d173 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..e539335 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=128,device_name=BW200.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=128,device_name=BW200.json new file mode 100644 index 0000000..31baf5f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=128,device_name=BW200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=128,device_name=BW200_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=128,device_name=BW200_nn.json new file mode 100644 index 0000000..4f7174e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=128,device_name=BW200_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=176,device_name=BW200_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=176,device_name=BW200_nn.json new file mode 100644 index 0000000..c5cc97b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=176,device_name=BW200_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=176,device_name=K100_AI_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=176,device_name=K100_AI_nn.json new file mode 100644 index 0000000..45d5275 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=176,device_name=K100_AI_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=BW200_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=BW200_nn.json new file mode 100644 index 0000000..78e88ed --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=BW200_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=K100_AI_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=K100_AI_nn.json new file mode 100644 index 0000000..a99b64f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=K100_AI_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..e1c4cac --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..5de5605 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H20-3e.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H20-3e.json new file mode 100644 index 0000000..b506820 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H20-3e.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H20.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H20.json new file mode 100644 index 0000000..2221e99 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H20.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..74374c5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c275cec --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=BW200_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=BW200_nn.json new file mode 100644 index 0000000..b14ffc3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=BW200_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 2, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=K100_AI_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=K100_AI_nn.json new file mode 100644 index 0000000..7f9fbac --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=K100_AI_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b34b6e4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..60ccde1 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20-3e.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20-3e.json new file mode 100644 index 0000000..b0139b9 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20-3e.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20.json new file mode 100644 index 0000000..ab169a0 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..324ad7b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..ab6e155 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..249359f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9942546 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b9dc2d7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b4efc9b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..3559f33 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20.json new file mode 100644 index 0000000..03dfc73 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9c07695 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..beaac7f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=BW200_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=BW200_nn.json new file mode 100644 index 0000000..f2e2d66 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=BW200_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=K100_AI,dtype=int4_w4a16.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=K100_AI,dtype=int4_w4a16.json new file mode 100644 index 0000000..0c4eb66 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=K100_AI,dtype=int4_w4a16.json @@ -0,0 +1,182 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 0 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 0 + }, + "6144": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "num_ldmatrixes": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=K100_AI_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=K100_AI_nn.json new file mode 100644 index 0000000..315a0a1 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=K100_AI_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H20.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H20.json new file mode 100644 index 0000000..ebff99e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H20.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..f10e394 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=BW200_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=BW200_nn.json new file mode 100644 index 0000000..186a5d8 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=BW200_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=DCU_K100_AI_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=DCU_K100_AI_nn.json new file mode 100644 index 0000000..8620b09 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=DCU_K100_AI_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=K100_AI_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=K100_AI_nn.json new file mode 100644 index 0000000..8620b09 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=K100_AI_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_B200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_B200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..beeb5a6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_B200,dtype=fp8_w8a8.json @@ -0,0 +1,147 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} + diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_B200.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_B200.json new file mode 100644 index 0000000..1fa444b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_B200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H100.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H100.json new file mode 100644 index 0000000..0442038 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H100.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json new file mode 100644 index 0000000..9262a74 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..d251f9b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..0ecf814 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json new file mode 100644 index 0000000..51ad5b2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..ee51191 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json new file mode 100644 index 0000000..68793c7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "5120": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "9216": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "13312": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "17408": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "25600": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "33792": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "41984": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "50176": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "58368": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..6129107 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "5120": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "9216": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "13312": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "17408": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "25600": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "33792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "41984": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "50176": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "58368": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..039a10e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..3793fca --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json new file mode 100644 index 0000000..51d03d8 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json new file mode 100644 index 0000000..26f9abd --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..cd0cdbe --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,130 @@ +{ + "3328": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "768": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "1792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2560": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2816": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3584": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "3840": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1280": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2304": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json new file mode 100644 index 0000000..64be6e6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..0a6a6a7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,218 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "5120": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "9216": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "13312": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "17408": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "25600": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "33792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "41984": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "50176": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "58368": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..ba9041d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,130 @@ +{ + "3840": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "1792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "3584": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2816": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1280": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "768": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "3328": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2560": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "2304": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json new file mode 100644 index 0000000..7a7508a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..dbf9a2d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json new file mode 100644 index 0000000..bbb2386 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..5705545 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,130 @@ +{ + "2048": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "1792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "3328": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2560": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "768": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2816": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2304": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2 + }, + "1280": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3840": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3584": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json new file mode 100644 index 0000000..0611620 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=80,device_name=BW200_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=80,device_name=BW200_nn.json new file mode 100644 index 0000000..1dbad24 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=80,device_name=BW200_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=80,device_name=K100_AI_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=80,device_name=K100_AI_nn.json new file mode 100644 index 0000000..a839b27 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=80,device_name=K100_AI_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=96,device_name=BW200_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=96,device_name=BW200_nn.json new file mode 100644 index 0000000..26e8885 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=96,device_name=BW200_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=160,N=96,device_name=K100_AI_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=160,N=96,device_name=K100_AI_nn.json new file mode 100644 index 0000000..c4c6aa7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=160,N=96,device_name=K100_AI_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325X,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325X,block_shape=[128,128].json new file mode 100644 index 0000000..43c249d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325X,block_shape=[128,128].json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..43c249d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=BW200,dtype=int4_w4a16.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=BW200,dtype=int4_w4a16.json new file mode 100644 index 0000000..eea4465 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=BW200,dtype=int4_w4a16.json @@ -0,0 +1,173 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=K100_AI,dtype=int4_w4a16.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=K100_AI,dtype=int4_w4a16.json new file mode 100644 index 0000000..d735553 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=K100_AI,dtype=int4_w4a16.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=K100_AI,dtype=int4_w4a16_120.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=K100_AI,dtype=int4_w4a16_120.json new file mode 100644 index 0000000..35a7273 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=K100_AI,dtype=int4_w4a16_120.json @@ -0,0 +1,182 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 0 + }, + "6144": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "num_ldmatrixes": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=K500SM_AI,dtype=int4_w4a16_120.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=K500SM_AI,dtype=int4_w4a16_120.json new file mode 100644 index 0000000..35a7273 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=K500SM_AI,dtype=int4_w4a16_120.json @@ -0,0 +1,182 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 0 + }, + "6144": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "num_ldmatrixes": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..4dd00d1 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json new file mode 100644 index 0000000..48f9697 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..a8c0571 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json new file mode 100644 index 0000000..f1244c6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..2e692a1 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..857d11e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..a2ee05d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..63e1187 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..e676960 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..e676960 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=BW200.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=BW200.json new file mode 100644 index 0000000..fc1fda8 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=BW200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..fc573cd --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..3e0ad0d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c6d7e96 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9264ca1 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..6fcf408 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c6eabea --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..e676960 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=BW200.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=BW200.json new file mode 100644 index 0000000..250bc5b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=BW200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=BW200_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=BW200_nn.json new file mode 100644 index 0000000..f6f3a62 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=BW200_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=DCU_K100_AI_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=DCU_K100_AI_nn.json new file mode 100644 index 0000000..44f8191 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=DCU_K100_AI_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=K100_AI,dtype=int4_w4a16.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=K100_AI,dtype=int4_w4a16.json new file mode 100644 index 0000000..e579b24 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=K100_AI,dtype=int4_w4a16.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=K100_AI,dtype=int4_w4a16_120.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=K100_AI,dtype=int4_w4a16_120.json new file mode 100644 index 0000000..e3d279f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=K100_AI,dtype=int4_w4a16_120.json @@ -0,0 +1,173 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=K100_AI_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=K100_AI_nn.json new file mode 100644 index 0000000..44f8191 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=K100_AI_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=K500SM_AI,dtype=int4_w4a16_120.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=K500SM_AI,dtype=int4_w4a16_120.json new file mode 100644 index 0000000..e3d279f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=K500SM_AI,dtype=int4_w4a16_120.json @@ -0,0 +1,173 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 0 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "num_ldmatrixes": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json new file mode 100644 index 0000000..21f6022 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=32,N=512,device_name=BW200_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=32,N=512,device_name=BW200_nn.json new file mode 100644 index 0000000..e07fc51 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=32,N=512,device_name=BW200_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=32,N=512,device_name=DCU_K100_AI_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=32,N=512,device_name=DCU_K100_AI_nn.json new file mode 100644 index 0000000..86a7bcc --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=32,N=512,device_name=DCU_K100_AI_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=32,N=512,device_name=K100_AI_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=32,N=512,device_name=K100_AI_nn.json new file mode 100644 index 0000000..86a7bcc --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=32,N=512,device_name=K100_AI_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=48,N=320,device_name=BW200_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=48,N=320,device_name=BW200_nn.json new file mode 100644 index 0000000..1e0c1ff --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=48,N=320,device_name=BW200_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=48,N=320,device_name=K100_AI_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=48,N=320,device_name=K100_AI_nn.json new file mode 100644 index 0000000..b40a1f4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=48,N=320,device_name=K100_AI_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=60,N=1408,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=60,N=1408,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..d09508b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=60,N=1408,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=60,N=176,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=60,N=176,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..746463a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=60,N=176,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=60,N=352,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=60,N=352,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..bbdb9ad --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=60,N=352,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=60,N=704,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=60,N=704,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..43584b1 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=60,N=704,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..8cc6c64 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json new file mode 100644 index 0000000..39a9912 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..05b5463 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..d4c9ddd --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..c17a4ec --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..170ae7f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1408,device_name=BW200.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1408,device_name=BW200.json new file mode 100644 index 0000000..0a31de1 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1408,device_name=BW200.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1408,device_name=BW200_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1408,device_name=BW200_nn.json new file mode 100644 index 0000000..9d9d561 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1408,device_name=BW200_nn.json @@ -0,0 +1,182 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1408,device_name=BW3000.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1408,device_name=BW3000.json new file mode 100644 index 0000000..625a419 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1408,device_name=BW3000.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1408,device_name=K100_AI.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1408,device_name=K100_AI.json new file mode 100644 index 0000000..f61512f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1408,device_name=K100_AI.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=256,device_name=BW200_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=256,device_name=BW200_nn.json new file mode 100644 index 0000000..c10cb59 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=256,device_name=BW200_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=256,device_name=DCU_K100_AI_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=256,device_name=DCU_K100_AI_nn.json new file mode 100644 index 0000000..9cde47e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=256,device_name=DCU_K100_AI_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=256,device_name=K100_AI_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=256,device_name=K100_AI_nn.json new file mode 100644 index 0000000..9cde47e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=256,device_name=K100_AI_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..1d9d352 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..9ad5b31 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..2883dfd --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..8abfd84 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..2fc18a5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..be8d4a7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..71fdd88 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=BW200.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=BW200.json new file mode 100644 index 0000000..1216e44 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=BW200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=BW200_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=BW200_nn.json new file mode 100644 index 0000000..3a544e6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=BW200_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=BW3000.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=BW3000.json new file mode 100644 index 0000000..1216e44 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=BW3000.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=K100_AI.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=K100_AI.json new file mode 100644 index 0000000..455f399 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=K100_AI.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=K100_AI_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=K100_AI_nn.json new file mode 100644 index 0000000..85a471a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=K100_AI_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..b2799ed --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json new file mode 100644 index 0000000..c02de2f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json new file mode 100644 index 0000000..3e0bc75 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..9f7ed67 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..b8d3be2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..21b7255 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..eaf32f6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=704,device_name=BW200.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=704,device_name=BW200.json new file mode 100644 index 0000000..62214bf --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=704,device_name=BW200.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=704,device_name=BW200_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=704,device_name=BW200_nn.json new file mode 100644 index 0000000..2eff61d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=704,device_name=BW200_nn.json @@ -0,0 +1,182 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 16, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=704,device_name=K100_AI.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=704,device_name=K100_AI.json new file mode 100644 index 0000000..4c0e8cc --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=704,device_name=K100_AI.json @@ -0,0 +1,114 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1 + }, + "6144": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1 + }, + "8192": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 1 + }, + "12288": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 6, + "num_warps": 4, + "num_stages": 1 + }, + "16384": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1 + }, + "32786": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=896,device_name=NVIDIA_H20.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=896,device_name=NVIDIA_H20.json new file mode 100644 index 0000000..5a9910a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=896,device_name=NVIDIA_H20.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..b6f1d01 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..4bf7753 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..f245285 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 0000000..3918c93 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..3f3ccda --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,138 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..841044a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..59be497 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..0e5fd1e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..d6ad635 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..16e0a91 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 0000000..d766fc0 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..8323f51 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..1b46cb5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..6d5b1ae --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 0000000..ffc1b23 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=BW200.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=BW200.json new file mode 100644 index 0000000..89cadc6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=BW200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=BW200_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=BW200_nn.json new file mode 100644 index 0000000..bbd52d7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=BW200_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=BW3000.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=BW3000.json new file mode 100644 index 0000000..89cadc6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=BW3000.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=K100_AI.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=K100_AI.json new file mode 100644 index 0000000..58eb11a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=K100_AI.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json new file mode 100644 index 0000000..f4c0f84 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..5c8185c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..97c9f44 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..e4110a5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..0883ef4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..81bb765 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..811c77a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..2758e48 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 0000000..fc31215 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=BW200.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=BW200.json new file mode 100644 index 0000000..5db568d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=BW200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=BW200_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=BW200_nn.json new file mode 100644 index 0000000..c4e3f92 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=BW200_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=BW3000.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=BW3000.json new file mode 100644 index 0000000..5db568d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=BW3000.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=DCU_K100_AI_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=DCU_K100_AI_nn.json new file mode 100644 index 0000000..5bb8754 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=DCU_K100_AI_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=K100_AI.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=K100_AI.json new file mode 100644 index 0000000..70ee3ee --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=K100_AI.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=K100_AI_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=K100_AI_nn.json new file mode 100644 index 0000000..1fd02cd --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=K100_AI_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 2, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 2, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 32, + "num_warps": 2, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 32, + "num_warps": 2, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 64, + "num_warps": 2, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 2, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..0bb423b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..5557187 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..26bcbf2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..1a0aa33 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..9952be6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..379ca10 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..5a3f415 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..6cb80f4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 0000000..de9d0ab --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=BW200.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=BW200.json new file mode 100644 index 0000000..2c8dc11 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=BW200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=BW200_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=BW200_nn.json new file mode 100644 index 0000000..c5f14fa --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=BW200_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=BW3000.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=BW3000.json new file mode 100644 index 0000000..2c8dc11 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=BW3000.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=K100_AI.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=K100_AI.json new file mode 100644 index 0000000..b2edb5e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=K100_AI.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=K100_AI_nn.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=K100_AI_nn.json new file mode 100644 index 0000000..0d7e6c6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=K100_AI_nn.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 32, + "num_warps": 2, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 4, + "num_ldmatrixes": 1 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "num_ldmatrixes": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json new file mode 100644 index 0000000..b41f9d4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..edf2a38 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json new file mode 100644 index 0000000..32bbadb --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..673bae2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..b2100ce --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..e6f753c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..53f3394 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_L40S.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_L40S.json new file mode 100644 index 0000000..d720deb --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_L40S.json @@ -0,0 +1,173 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_ctas": 1, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 2, + "num_warps": 4, + "num_ctas": 1, + "num_stages": 7 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 128, + "num_warps": 2, + "num_ctas": 1, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_ctas": 1, + "num_stages": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_ctas": 1, + "num_stages": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 2, + "num_warps": 4, + "num_ctas": 1, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 2, + "num_warps": 4, + "num_ctas": 1, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 2, + "num_warps": 4, + "num_ctas": 1, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_ctas": 1, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_ctas": 1, + "num_stages": 2 + }, + "192": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_ctas": 1, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 16, + "num_ctas": 1, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 128, + "num_warps": 2, + "num_ctas": 1, + "num_stages": 8 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_ctas": 1, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 16, + "num_ctas": 1, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 16, + "num_ctas": 1, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_ctas": 1, + "num_stages": 2 + }, + "6144": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_ctas": 1, + "num_stages": 2 + }, + "8192": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 16, + "num_ctas": 1, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..48bb5f2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..a64d06c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..2c49f35 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 0000000..c7db6c0 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..dbc6247 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..cc614e6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..32c0c9d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..4dd475c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..2ed15f3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..bd2c6fb --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..8d7b780 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..7a07bbf --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 0000000..3a3268c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=K100_AI.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=K100_AI.json new file mode 100644 index 0000000..1ae95f4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=K100_AI.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000..f578c8d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..918f683 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000..e341a67 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..eb81726 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..0c7062a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..cd4fb8f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000..cf66868 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json new file mode 100644 index 0000000..c27ca0a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X.json new file mode 100644 index 0000000..da477b1 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000..34b916e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..96cbc11 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/README b/vllm/model_executor/layers/fused_moe/configs/README new file mode 100644 index 0000000..85970e2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/README @@ -0,0 +1,12 @@ +This directory contains tuned configurations for different settings of the fused_moe kernel. +For different settings of +- E (number of experts) +- N (intermediate size) +- device_name (torch.cuda.get_device_name()) +the JSON file contains a mapping from M (batch size) to the chosen configuration. + +The example configurations provided are for the Mixtral model for TP2 on H100 +and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have +N = 7168 and for TP4 we have N = 3584. + +See `benchmark/kernels/benchmark_moe.py` on how to generate these config files. diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py new file mode 100644 index 0000000..e67ff66 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py @@ -0,0 +1,215 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Callable, Optional + +import torch + +from vllm import envs + + +class IPEXFusedMOE: + + def __init__(self, layer: torch.nn.Module) -> None: + import intel_extension_for_pytorch as ipex + layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( + layer.w13_weight, + layer.w2_weight, + use_prepack=envs.VLLM_CPU_MOE_PREPACK, + ) + + def __call__( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert activation == "silu", f"{activation} is not supported." + assert not apply_router_weight_on_input + return layer.ipex_fusion( + x, + use_grouped_topk, + top_k, + router_logits, + renormalize, + topk_group, + num_expert_group, + custom_routing_function, + scoring_func, + e_score_correction_bias, + ) + + +class SGLFusedMOE: + + def __init__(self, layer: torch.nn.Module) -> None: + pass + + @staticmethod + def _grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + + gating_output = gating_output.float() + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + num_token = scores.shape[0] + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use + # biased scores for expert selection but original scores for + # routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = (scores.view(num_token, num_expert_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) + else: + group_scores = scores.view(num_token, num_expert_group, + -1).max(dim=-1).values # [n, n_group] + group_idx = torch.topk(group_scores, + k=topk_group, + dim=-1, + sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.shape[-1] // num_expert_group).reshape(num_token, + -1) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), + float("-inf")) # [n, e] + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, + k=topk, + dim=-1, + sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, + keepdim=True) + + return topk_weights, topk_ids.to(torch.int32) + + @staticmethod + def _select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # DeekSeekv2 uses grouped_top_k + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + topk_weights, topk_ids = SGLFusedMOE._grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + elif custom_routing_function is None: + assert scoring_func == "softmax" + topk_weights = torch.nn.functional.softmax(router_logits, + dim=1, + dtype=torch.float32) + topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1) + if renormalize: + topk_weights /= topk_weights.sum(dim=-1, keepdim=True) + topk_ids = topk_ids.to(torch.int32) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) + + return topk_weights, topk_ids + + def __call__( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert activation == "silu", f"{activation} is not supported." + assert not apply_router_weight_on_input + topk_weights, topk_ids = SGLFusedMOE._select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + + torch.ops._C.fused_experts_cpu( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + True, + False, + False, + None, + None, + None, + None, + None, + True, + ) + return x diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py new file mode 100644 index 0000000..0f41414 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -0,0 +1,645 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" CUTLASS based Fused MoE kernels.""" +from typing import Callable, Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP) +from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, + _fp8_quantize, + _resize_cache) +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + + +def run_cutlass_moe_fp8( + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation_callable: Callable, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + out_dtype: torch.dtype, + per_act_token: bool, + per_out_ch: bool, + use_batched_format: bool, +): + a1q = hidden_states + + assert w1_scale is not None + assert w2_scale is not None + assert w1.dtype == torch.float8_e4m3fn + assert w2.dtype == torch.float8_e4m3fn + assert a1q.size(-1) == w1.size(2), "Hidden size mismatch w1" + assert w1.size(1) == w2.size(2) * 2, "Hidden size mismatch w2" + assert w1_scale.dim() == 1 or w1_scale.size( + 1) == 1 or w1_scale.shape[1] == w1.size(1), "W1 scale shape mismatch" + assert w2_scale.dim() == 1 or w2_scale.size( + 1) == 1 or w2_scale.shape[1] == w2.size(1), "W2 scale shape mismatch" + assert w1.size(0) == w2.size(0), "Expert number mismatch" + assert a1q_scale is None or a1q_scale.dim() == 0 or a1q_scale.size( + 0) == 1 or a1q_scale.size( + 0) == a1q.shape[0], "Input scale shape mismatch" + assert w1.size(0) == w2.size(0), "Weights expert number mismatch" + assert w1.size(0) == w1_scale.size(0), "w1 scales expert number mismatch" + assert w1.size(0) == w2_scale.size(0), "w2 scales expert number mismatch" + assert a2_scale is None or a2_scale.dim() == 0 or a2_scale.size( + 0) == 1 or a2_scale.size( + 0) == a1q.shape[0], "Intermediate scale shape mismatch" + assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" + if expert_map is not None: + assert expert_num_tokens is None + + # We have two modes: batched experts and non-batched experts. + # In the non-batched mode, the input tokens are not padded: thus, the shape + # of the input is [total_num_tokens, hidden_size]. The input and output + # require shuffling by a_map and c_map such that the tokens assigned to + # each expert are contiguous. + # In the batched mode, the input tokens are padded per expert to ensure that + # the batched dispatch and combine functions work correctly: thus, the shape + # of the input is [num_experts, max_num_tokens_per_expert, hidden_size]. + # The batched input and output require no shuffling by a_map and c_map since + # their tokens are already contiguous for each expert as a result of + # the dispatch function. + + M = a1q.size(0) # non batched expert M + padded_M = a1q.size(1) # batched expert M + _, K, N = w2.shape + device = a1q.device + + assert w1.size(2) == K + assert global_num_experts != -1 + assert a1q_scale is not None + + if expert_map is not None: + "Translate info from expert_map to topk_ids" + local_topk_ids = torch.where(expert_map[topk_ids] != -1, + expert_map[topk_ids], -1) + else: + local_topk_ids = topk_ids + + topk = local_topk_ids.size(1) + local_E = w1.size(0) + + if use_batched_format: + assert expert_num_tokens is not None + + expert_offsets = torch.empty((local_E), + dtype=torch.int32, + device=device) + problem_sizes1 = torch.empty((local_E, 3), + dtype=torch.int32, + device=device) + problem_sizes2 = torch.empty((local_E, 3), + dtype=torch.int32, + device=device) + + ops.get_cutlass_pplx_moe_mm_data(expert_offsets, problem_sizes1, + problem_sizes2, expert_num_tokens, + local_E, padded_M, N, K) + + w1_scale = w1_scale.reshape(w1_scale.size(0), -1) + w2_scale = w2_scale.reshape(w2_scale.size(0), -1) + a1q = a1q.reshape(-1, a1q.size(2)) + a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous() + + else: + expert_offsets = torch.empty((global_num_experts + 1), + dtype=torch.int32, + device=device) + problem_sizes1 = torch.empty((global_num_experts, 3), + dtype=torch.int32, + device=device) + problem_sizes2 = torch.empty((global_num_experts, 3), + dtype=torch.int32, + device=device) + + # With expert_map each Rank processes only a subset of experts. As + # a result not all of a_map and c2 tensors are filled. We fill it + # zeros for correctness. + if expert_map is not None: + a_map = torch.zeros((local_topk_ids.numel()), + dtype=torch.int32, + device=device) + else: + a_map = torch.empty((local_topk_ids.numel()), + dtype=torch.int32, + device=device) + + c_map = torch.empty((local_topk_ids.numel()), + dtype=torch.int32, + device=device) + + ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, + problem_sizes1, problem_sizes2, a_map, + c_map, global_num_experts, N, K) + + a1q = _fp8_perm(a1q, a_map) + a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale + expert_offsets = expert_offsets[:-1] + + ab_strides1 = torch.full((w1.size(0), ), + K, + device=device, + dtype=torch.int64) + c_strides1 = torch.full((w1.size(0), ), + 2 * N, + device=device, + dtype=torch.int64) + ab_strides2 = torch.full((w1.size(0), ), + N, + device=device, + dtype=torch.int64) + c_strides2 = torch.full((w1.size(0), ), + K, + device=device, + dtype=torch.int64) + + if use_batched_format: + c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2)) + c2 = _resize_cache(workspace2, (local_E * padded_M, N)) + c3 = _resize_cache(workspace13, (local_E * padded_M, K)) + else: + c1 = _resize_cache(workspace13, (M * topk, N * 2)) + c2 = _resize_cache(workspace2, (M * topk, N)) + c3 = _resize_cache(workspace13, (M * topk, K)) + + c1.fill_(0) + + ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets, + problem_sizes1, ab_strides1, ab_strides1, c_strides1, + per_act_token, per_out_ch) + + activation_callable(c2, c1) + + a2q, a2q_scale = ops.scaled_fp8_quant( + c2, a2_scale, use_per_token_if_dynamic=per_act_token) + + if expert_map is not None: + c3.fill_(0) + + ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets, + problem_sizes2, ab_strides2, ab_strides2, c_strides2, + per_act_token, per_out_ch) + + if use_batched_format: + output.copy_(c3.reshape(local_E, padded_M, K), non_blocking=True) + else: + # We can't do this inplace because output may point to the same tensor + # as c3. + output.copy_(c3[c_map].view(M * topk, K), non_blocking=True) + + +# TODO (bnell): split class batched vs. non-batched? +# maybe remove need for passing aq to workspace_shapes +class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + max_experts_per_worker: int, + out_dtype: Optional[torch.dtype], + per_act_token_quant: bool, + per_out_ch_quant: bool, + block_shape: Optional[list[int]] = None, + num_dispatchers: Optional[int] = None, + use_batched_format: bool = False, + ): + super().__init__( + FusedMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_out_ch_quant, + block_shape=block_shape, + )) + assert max_experts_per_worker > 0 + assert not use_batched_format or num_dispatchers is not None + self.max_experts_per_worker = max_experts_per_worker + self.num_dispatchers = num_dispatchers + self.out_dtype = out_dtype + self.use_batched_format = use_batched_format + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + if self.use_batched_format: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) + else: + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) + + def supports_chunking(self) -> bool: + return not self.use_batched_format + + def supports_expert_map(self) -> bool: + return not self.use_batched_format + + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + workspace1: tuple[int, ...] = () + workspace2: tuple[int, ...] = () + output: tuple[int, ...] = () + if self.use_batched_format: + padded_M = aq.size(1) + num_dp = self.num_dispatchers + assert num_dp is not None + workspace1 = (self.max_experts_per_worker, padded_M * num_dp, + max(N, K)) + workspace2 = (self.max_experts_per_worker, padded_M * num_dp, + (N // 2)) + output = (self.max_experts_per_worker, padded_M, K) + else: + workspace1 = (M * topk, max(2 * N, K)) + workspace2 = (M * topk, N) + output = (M * topk, K) + return (workspace1, workspace2, output, + self.out_dtype if self.out_dtype is not None else a.dtype) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ): + assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" + assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" + activation_callable = lambda i, o: self.activation(activation, i, o) + in_dtype = hidden_states.dtype + run_cutlass_moe_fp8( + output, hidden_states, w1, w2, topk_ids, activation_callable, + global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale, + a2_scale, workspace13, workspace2, expert_num_tokens, + self.out_dtype if self.out_dtype is not None else in_dtype, + self.per_act_token_quant, self.per_out_ch_quant, + self.use_batched_format) + + +def cutlass_moe_fp8( + a: torch.Tensor, + w1_q: torch.Tensor, + w2_q: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + per_act_token: Optional[bool] = None, + activation: str = "silu", + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + global_num_experts: int = -1, +) -> torch.Tensor: + """ + This function computes a a8w8-quantized Mixture of Experts (MoE) layer + using two sets of quantized weights, w1_q and w2_q, and top-k gating + mechanism. The matrix multiplications are implemented with CUTLASS + grouped gemm. + + Parameters: + - a (torch.Tensor): The input tensor to the MoE layer. + Shape: [M, K] + - w1_q (torch.Tensor): The first set of fp8-quantized expert weights. + Shape: [num_experts, K, 2N] (the weights are passed transposed) + - w2_q (torch.Tensor): The second set of fp8-quantized expert weights. + Shape: [num_experts, N, K] (the weights are passed transposed) + - topk_weights (torch.Tensor): The weights of each token->expert mapping. + - topk_ids (torch.Tensor): The token->expert mappings. + - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. + Shape: [num_experts] or [num_experts, 2N] + - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. + Shape: [num_experts] or [num_experts, K] + - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. + Shape: scalar or [M] + - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to + quantize the intermediate result between the gemms. + Shape: scalar or [M] + - expert_map (Optional[torch.Tensor]): In the case of Expert parallel, + every Rank is responsible for a subset of experts. expert_map is a + mapping from global expert-id to local expert-id. When expert_map[i] + is -1, it means that this Rank is not responsible for global + expert-id i. + - apply_router_weight_on_input (bool): When true, the topk weights are + applied directly on the inputs. This is only applicable when topk is 1. + - global_num_experts (int): The total number of experts. + + Returns: + - torch.Tensor: The fp16 output tensor after applying the MoE layer. + """ + if per_act_token is None: + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + per_out_ch = w1_scale.numel() != w1_q.size(0) + + num_experts = global_num_experts if global_num_experts != -1 else w1_q.size( + 0) + + fn = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + CutlassExpertsFp8( + max_experts_per_worker=num_experts, + out_dtype=a.dtype, + per_act_token_quant=per_act_token, + per_out_ch_quant=per_out_ch, + use_batched_format=False, + ), + ) + + return fn( + a, + w1_q, + w2_q, + topk_weights, + topk_ids, + False, + activation, + num_experts, + expert_map, + w1_scale, + w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + + +def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, + w1_fp4: torch.Tensor, w1_blockscale: torch.Tensor, + w1_alphas: torch.Tensor, a2_gscale: torch.Tensor, + w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor, + w2_alphas: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, m: int, n: int, k: int, e: int, + device: torch.device): + """ + MoE implementation for FP4 Inputs + + # Gemm 1 + a: Input tensor: [m, k] (half/bfloat16) + a1_gscale: Activation scale per expert: [e] (float32) + w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k] + w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1) + (Note: `n` is the up projection output dim, `k` is the input dim in + full precision) + w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3) + (Block size = 16 for NVFP4) + + # Gemm 2 + a2_gscale: Activation scale per expert: [e] + w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n] + w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1) + w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3 + + topk_weights: [m, topk] dtype: float8 + topk_ids: [m, topk] dtype: float8 + + m, n, k: Unquantized weight shapes, dtype: int + e: number of experts, dtype: int + + assumes that topk < k < n to satisfy - up/down projection expectations. + """ + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8" + assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8" + assert (w1_fp4.ndim == 3 and w2_fp4.ndim == 3 and w1_blockscale.ndim == 3 + and w2_blockscale.ndim + == 3), ("All Weights must be of rank 3 for cutlass_moe_fp4") + m_a, k_a = a.shape + e_w1, nx2_w1, half_k_w1 = w1_fp4.shape + e_w2, k_w2, half_n_w2 = w2_fp4.shape + + assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match", + " between weights.") + assert (k_a // 2 == half_k_w1 + and k == k_w2), ("Hidden size mismatch between a, w1 and w2") + assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), ("mismatch in " + "expected `n`") + assert (m == m_a), "input shape mismatch" + assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" + assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" + assert (topk_weights.size(0) == m and topk_ids.size(0) + == m), ("topk must be provided for each row of a") + + out_dtype = a.dtype + num_topk = topk_ids.size(1) + + expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device) + blockscale_offsets = torch.empty((e + 1), dtype=torch.int32, device=device) + # Problem size: (num_experts, (m,2n,k)) + problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device) + # Problem size: (num_experts, (m,n,k)) + problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device) + + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + + # problem shapes should have [m, n, k] + # Note that problem sizes are based on logical number of elements. + ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, + problem_sizes2, a_map, c_map, e, n, k, + blockscale_offsets) + + a = ops.shuffle_rows(a, a_map) + + rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant( + a, + a1_gscale, + expert_offsets, + blockscale_offsets, + num_topk, + ) + + c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale, + w1_blockscale, w1_alphas, problem_sizes1, + expert_offsets[:-1], blockscale_offsets[:-1], + out_dtype, device) + del rep_a_fp4, rep_a_blockscale + # hidden size dimension is split to one halfpytho sized tensor. + intermediate = torch.empty((m * num_topk, w1_fp4.size(1) // 2), + device=device, + dtype=out_dtype) + + torch.ops._C.silu_and_mul(intermediate, c1) + + int_fp4, int_blockscale = ops.scaled_fp4_experts_quant( + intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk) + + c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale, + w2_alphas, problem_sizes2, expert_offsets[:-1], + blockscale_offsets[:-1], out_dtype, device) + del int_fp4, int_blockscale + + c2 = ops.shuffle_rows(c2, c_map) + out = (c2.view(m, num_topk, k) * + topk_weights.view(m, num_topk, 1).half()).sum(dim=1) + return out.to(dtype=out_dtype) + + +def _valid_cutlass_block_scaled_grouped_gemm(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor) -> bool: + + def _valid_cutlass_block_scaled_grouped_gemm_shape(M: int, N: int, K: int): + return M >= 128 and N % 128 == 0 and K % 128 == 0 + + m = hidden_states.size(0) + _, K, N = w2.size() + if not _valid_cutlass_block_scaled_grouped_gemm_shape(m, N, K): + logger.debug( + "CutlassBlockScaledGroupedGemm disabled: unalinged problem size.") + return False + + if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn): + logger.debug( + "CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s).") + return False + + return True + + +def run_cutlass_block_scaled_fused_experts( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, +) -> torch.Tensor: + w1_q = w1.transpose(1, 2) + w2_q = w2.transpose(1, 2) + w1_scale = w1_scale.transpose(1, 2) + w2_scale = w2_scale.transpose(1, 2) + + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert a.shape[0] == topk_ids.shape[ + 0], "a and topk_ids must have the same batch size" + assert w1_q.dtype == torch.float8_e4m3fn, "w1_q must be float8_e4m3fn" + assert w2_q.dtype == torch.float8_e4m3fn, "w2_q must be float8_e4m3fn" + assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" + assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" + assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" + assert w1_q.shape[0] == w1_scale.shape[ + 0], "w1_scale expert number mismatch" + assert w1_q.shape[0] == w2_scale.shape[ + 0], "w2_scale expert number mismatch" + assert a.dtype in [torch.half, torch.bfloat16], "Invalid output dtype" + + out_dtype = a.dtype + num_experts = w1_q.size(0) + m = a.size(0) + k = w1_q.size(1) + n = w2_q.size(1) + + expert_offsets = torch.empty((num_experts + 1, ), + dtype=torch.int32, + device="cuda") + problem_sizes1 = torch.empty((num_experts, 3), + dtype=torch.int32, + device="cuda") + problem_sizes2 = torch.empty((num_experts, 3), + dtype=torch.int32, + device="cuda") + + topk = topk_ids.size(1) + + a_q, a1_scale = _fp8_quantize(a, + A_scale=None, + per_act_token=False, + block_shape=[128, 128]) + device = a_q.device + + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + + ops.get_cutlass_moe_mm_data( + topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + a_map, + c_map, + num_experts, + n, + k, + ) + + rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) + rep_a1_scales = a1_scale[a_map] + + c1 = torch.empty((m * topk, n * 2), dtype=out_dtype, device=device) + c2 = torch.empty((m * topk, k), dtype=out_dtype, device=device) + + ops.cutlass_blockwise_scaled_grouped_mm( + c1, + rep_a_q, + w1_q, + rep_a1_scales, + w1_scale, + problem_sizes1, + expert_offsets[:-1], + ) + + intermediate = torch.empty((m * topk, n), dtype=out_dtype, device=device) + torch.ops._C.silu_and_mul(intermediate, c1) + + intermediate_q, a2_scale = _fp8_quantize(intermediate, + A_scale=None, + per_act_token=False, + block_shape=[128, 128]) + + ops.cutlass_blockwise_scaled_grouped_mm( + c2, + intermediate_q, + w2_q, + a2_scale, + w2_scale, + problem_sizes2, + expert_offsets[:-1], + ) + + return (c2[c_map].view(m, topk, k) * + topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py new file mode 100644 index 0000000..8ad57c2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -0,0 +1,250 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +from typing import Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + _moe_permute) +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP) +from vllm.model_executor.layers.fused_moe.utils import ( + _resize_cache, per_token_group_quant_fp8) +from vllm.utils import has_deep_gemm, round_up + +logger = init_logger(__name__) + + +@functools.cache +def deep_gemm_block_shape() -> list[int]: + # Lazy import to avoid CUDA initialization problems. + import deep_gemm as dg + block = dg.get_m_alignment_for_contiguous_layout() + return [block, block] + + +def _valid_deep_gemm_shape(M: int, N: int, K: int): + align = deep_gemm_block_shape()[0] + return align <= M and N % align == 0 and K % align == 0 + + +def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, + w2: torch.Tensor) -> bool: + """ + Check if the given problem size is supported by the DeepGemm grouped + gemm kernel. All of M, N, K and the quantization block_shape must be + aligned by `dg.get_m_alignment_for_contiguous_layout()`. + """ + if not has_deep_gemm(): + logger.debug("DeepGemm disabled: deep_gemm not available.") + return False + + M = hidden_states.size(0) + _, K, N = w2.size() + if not _valid_deep_gemm_shape(M, N, K): + logger.debug("DeepGemm disabled: unaligned problem size.") + return False + + if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn): + logger.debug("DeepGemm disabled: invalid weight dtype(s).") + return False + + if (not hidden_states.is_contiguous() or not w1.is_contiguous() + or not w2.is_contiguous()): + logger.debug( + "DeepGemm disabled: weights or activations not contiguous.") + return False + + return True + + +class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__(self): + super().__init__( + FusedMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_act_token_quant=False, + block_shape=deep_gemm_block_shape(), + )) + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) + + def supports_chunking(self) -> bool: + return True + + def supports_expert_map(self) -> bool: + return True + + def workspace_shapes( + self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, + topk: int, global_num_experts: int, local_num_experts: int + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + assert self.block_shape is not None + # We use global_num_experts due to how moe_align_block_size handles + # expert_maps. + num_experts = global_num_experts + block_m = self.block_shape[0] + M_sum = (M * topk) + num_experts * (block_m - 1) + M_sum = round_up(M_sum, block_m) + workspace1 = (M_sum, max(N * 2, K)) + workspace2 = (M_sum, max(N, K)) + output = (M * topk, K) + return (workspace1, workspace2, output, a.dtype) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ): + import deep_gemm as dg + assert self.block_shape is not None + + a1q = hidden_states + _, N, K = w1.size() + + if global_num_experts == -1: + global_num_experts = w1.size(0) + + assert w2.size(1) == K + + a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute( + a1q, + a1q_scale, + topk_ids, + global_num_experts, + expert_map, + self.block_shape[0], + ) + + if expert_map is not None: + # DeepGemm (Grouped Contiguous) kernel needs a valid B index + # for all rows of A. To that effect, simply compute with + # the 0th weight matrix. + # Note that this relies on the fact that corresponding topk + # weights would be 0 during weight multiplication. + expert_ids = torch.where(expert_ids == -1, 0, expert_ids) + + # Note: M_sum is different than the pre-permuted shape of a1q. + M_sum = a1q.size(0) + + mm1_out = _resize_cache(workspace13, (M_sum, N)) + act_out = _resize_cache(workspace2, (M_sum, N // 2)) + quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), + (M_sum, N // 2)) + mm2_out = _resize_cache(workspace2, (M_sum, K)) + + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids) + + self.activation(activation, act_out, mm1_out.view(-1, N)) + + a2q_scale: Optional[torch.Tensor] = None + a2q, a2q_scale = per_token_group_quant_fp8(act_out, + self.block_shape[1], + column_major_scales=True, + out_q=quant_out) + + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids) + + torch.index_select(mm2_out, 0, inv_perm, out=output) + + +def deep_gemm_moe_fp8( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + apply_router_weight_on_input=False, +) -> torch.Tensor: + """ + This function computes a a8w8-quantized Mixture of Experts (MoE) layer + using two sets of quantized weights, w1_q and w2_q, and top-k gating + mechanism. The matrix multiplications are implemented with DeepGemm + grouped gemm. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + Shape: [M, K] + - w1 (torch.Tensor): The first set of fp8 quantized expert weights. + Shape: [num_experts, K, 2N] (the weights are passed transposed) + - w2 (torch.Tensor): The second set of fp8 quantized expert weights. + Shape: [num_experts, N, K] (the weights are passed transposed) + - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. + Shape: [num_experts] or [num_experts, 2N] + - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. + Shape: [num_experts] or [num_experts, K] + - topk_weights (torch.Tensor): The weights of each token->expert mapping. + - topk_ids (torch.Tensor): The token->expert mapping for topk_weights. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - activation (str): The activation function to apply after the first + MoE layer. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. + Shape: scalar or [M] + - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to + quantize the intermediate result between the gemms. + Shape: scalar or [M] + + Returns: + - torch.Tensor: The bfloat16 output tensor after applying the MoE layer. + """ + fn = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + DeepGemmExperts(), + ) + return fn( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace, + activation, + global_num_experts, + expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py new file mode 100644 index 0000000..b625c28 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -0,0 +1,231 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import deep_ep +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) + + +class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): + """ + Prepare/Finalize using DeepEP High-Throughput kernels. + """ + + def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int, + dp_size: int, rank_expert_offset: int): + super().__init__() + self.buffer = buffer + self.num_dispatchers_ = num_dispatchers + self.dp_size = dp_size + self.rank_expert_offset = rank_expert_offset + # The dispatch function returns a handle that the combine function + # requires. We store the handle here so it is available to the + # combine function. + self.handle = None + + # From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164 + self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160] + + def num_dispatchers(self) -> int: + return self.num_dispatchers_ + + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> Optional[int]: + return None + + def topk_indices_dtype(self) -> Optional[torch.dtype]: + return torch.int64 + + def _get_dispatch_config(self) -> Optional[deep_ep.Config]: + if self.dp_size not in self.available_rank_configs: + return None + return deep_ep.Buffer.get_dispatch_config(self.dp_size) + + def _get_combine_config(self) -> Optional[deep_ep.Config]: + if self.dp_size not in self.available_rank_configs: + return None + return deep_ep.Buffer.get_combine_config(self.dp_size) + + def _do_dispatch(self, tokens: torch.Tensor, + token_scales: Optional[torch.Tensor], + rank_topk_ids: torch.Tensor, + rank_topk_weights: torch.Tensor, num_experts: int): + + has_scales = token_scales is not None + + (num_tokens_per_rank, num_tokens_per_rdma_rank, expert_num_tokens, + is_token_in_rank, event) = self.buffer.get_dispatch_layout( + topk_idx=rank_topk_ids, + num_experts=num_experts, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False) + + token_data = tokens + if has_scales: + token_data = (tokens, token_scales) + + ( + token_data, expert_topk_ids, expert_topk_weights, + expert_num_tokens_per_expert_list, self.handle, event + ) = self.buffer.dispatch( + x=token_data, + handle=None, + num_tokens_per_rank=num_tokens_per_rank, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + is_token_in_rank=is_token_in_rank, + num_tokens_per_expert=expert_num_tokens, + topk_idx=rank_topk_ids, + topk_weights=rank_topk_weights, + # expert_alignment rounds the number of tokens per expert + # to this value. + expert_alignment=1, + config=self._get_dispatch_config(), + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False) + + if has_scales: + expert_x, expert_x_scale = token_data + else: + expert_x, expert_x_scale = token_data, None + + # The existing MOE kernels assume that all entries of topk_ids are + # valid. To that effect, set the -1s in expert_topk_ids to some expert + # outside this rank so the expert_map can remap it to -1 when safe. + # With Expert Parallel, the experts are divided amongst the rank + # sequentially. For rank 0, set it to num_experts - 1 and for all other + # ranks set it to 0 as we know that expert_map will have a -1 in those + # regions for those ranks. + # + # DeepEP's topk_ids output refers to the local experts directly. Offset + # the topk_ids to move it back to the global experts space so it aligns + # with existing vLLM interfaces. + expert_topk_ids = torch.where( + expert_topk_ids == -1, + num_experts - 1 if self.rank_expert_offset == 0 else 0, + expert_topk_ids + self.rank_expert_offset) + + return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, + expert_topk_weights) + + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor], Optional[torch.Tensor]]: + + if apply_router_weight_on_input: + topk = topk_ids.size(1) + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1") + a1 = a1 * topk_weights.to(a1.dtype) + + if quant_config.per_act_token_quant: + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + a1_scale, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=True, + block_shape=quant_config.block_shape, + ) + if a1q_scale is not None and a1q_scale.numel() == 1: + a1q_scale = a1q_scale.view(1, 1) + (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, + expert_topk_weights) = self._do_dispatch( + tokens=a1q, + token_scales=a1q_scale, + rank_topk_ids=topk_ids, + rank_topk_weights=topk_weights, + num_experts=num_experts) + else: + # DeepEP kernels only support dispatching per-token-quant + # quantization. dispatch in bfloat16. + (expert_x, _, expert_num_tokens, expert_topk_ids, + expert_topk_weights) = self._do_dispatch( + tokens=a1, + token_scales=None, + rank_topk_ids=topk_ids, + rank_topk_weights=topk_weights, + num_experts=num_experts) + # quantize now + expert_x_scale = None + if expert_x.numel() != 0: + expert_x, expert_x_scale = moe_kernel_quantize_input( + expert_x, + a1_scale, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=False, + block_shape=quant_config.block_shape) + + return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, + expert_topk_weights) + + def _apply_weights_and_reduce(self, num_tokens: int, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + apply_router_weight_on_input: bool, + output_dtype: torch.dtype): + + hidden_dim = fused_expert_output.size(-1) + if fused_expert_output.ndim == 2: + fused_expert_output = fused_expert_output.view( + num_tokens, -1, hidden_dim) + + if not apply_router_weight_on_input: + # The DeepEP combine kernels don't do the topk weight + # multiplication. We multiply the weights locally. + m_x_topk = fused_expert_output.size(0) + fused_expert_output.mul_(topk_weights.view(m_x_topk, -1, 1)) + + out = torch.empty((num_tokens, hidden_dim), + device=fused_expert_output.device, + dtype=output_dtype) + ops.moe_sum(fused_expert_output, out) + + return out + + def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool) -> None: + + assert self.handle is not None + + # fused_expert_output can have 0 tokens - This happens when none of the + # tokens from the all2all reach this EP rank. + if fused_expert_output.numel() != 0: + fused_expert_output = self._apply_weights_and_reduce( + num_tokens=topk_ids.size(0), + fused_expert_output=fused_expert_output, + topk_weights=topk_weights, + apply_router_weight_on_input=apply_router_weight_on_input, + output_dtype=output.dtype) + + combined_x, _, event = self.buffer.combine( + x=fused_expert_output, + handle=self.handle, + topk_weights=None, + config=self._get_combine_config(), + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False) + # Respect inplace outputs. + output.copy_(combined_x, non_blocking=True) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py new file mode 100644 index 0000000..78ac4ac --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -0,0 +1,183 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional, Union + +import deep_ep +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input, normalize_batched_scales_shape) + +# DeepEP kernels quantize dispatch inputs in 128 element chunks. +DEEPEP_QUANT_BLOCK_SIZE = 128 +DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE] + + +def dequant_fp8(expert_x_fp8: torch.Tensor, + expert_x_scales: torch.Tensor) -> torch.Tensor: + """ + Return dequantized tensor in fp32 + """ + # TODO (varun) : Optimize leverage num_tokens_per_expert counts + assert expert_x_fp8.is_contiguous() + expert_x_scales = expert_x_scales.contiguous() + num_experts = expert_x_fp8.size(0) + + expert_x_fp32 = expert_x_fp8.to(torch.float32).view( + num_experts, -1, DEEPEP_QUANT_BLOCK_SIZE) + expert_x_scales = expert_x_scales.view(num_experts, -1, 1) + return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size()) + + +class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): + """ + Prepare/Finalize using DeepEP low-latency kernels. + """ + + # DeepEP low-latency kernels are compiled only for certain + # specific hidden sizes. + SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 7168] + + def __init__(self, + buffer: deep_ep.Buffer, + max_tokens_per_rank: int, + num_dispatchers: int, + use_fp8_dispatch: bool = False): + super().__init__() + + self.buffer = buffer + self.max_tokens_per_rank = max_tokens_per_rank + self.use_fp8_dispatch = use_fp8_dispatch + # The dispatch function returns a handle that the combine function + # requires. We store the handle here so it is available to the + # combine function. + self.handle = None + self.num_dispatchers_ = num_dispatchers + + def num_dispatchers(self) -> int: + return self.num_dispatchers_ + + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts + + def max_num_tokens_per_rank(self) -> Optional[int]: + return self.max_tokens_per_rank + + def topk_indices_dtype(self) -> Optional[torch.dtype]: + return torch.int64 + + def _do_quant( + self, + x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + a1_dtype: torch.dtype, + quant_dtype: Optional[torch.dtype], + per_act_token_quant: bool, + block_shape: Optional[list[int]], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + + block_k = block_shape[1] if block_shape is not None else None + if self.use_fp8_dispatch: + if block_k == DEEPEP_QUANT_BLOCK_SIZE: + # DeepEP kernels did the quantization for us. + x, x_scales = x + return x, x_scales + + # Dequant to get back the tokens in the datatype we dispatched in. + x_fp8, x_scales = x + x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype) + + assert isinstance(x, torch.Tensor) + + num_experts, max_tokens, hidden_dim = x.size() + + # TODO (varun): Optimization - Use a batched version of quant + x = x.view((-1, hidden_dim)) + x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype, + per_act_token_quant, + block_shape) + x = x.view((num_experts, -1, hidden_dim)) + + if quant_dtype is not None: + assert x_scales is not None + x_scales = normalize_batched_scales_shape(x_scales, num_experts) + + return x, x_scales + + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor], Optional[torch.Tensor]]: + + hidden_size = a1.size(1) + assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ + (f"Hidden Size {hidden_size} not in supported list of hidden sizes" + f"{self.SUPPORTED_HIDDEN_SIZES}") + + if self.use_fp8_dispatch: + assert hidden_size % 128 == 0, \ + "DeepEP kernels quantize the inputs in blocks of shape 128" + + has_per_token_scales = a1_scale.numel( + ) != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + assert not has_per_token_scales, ( + "low_latency kernels doesn't support dispatching per-token scales") + + if apply_router_weight_on_input: + topk = topk_ids.size(1) + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1") + a1 = a1 * topk_weights.to(a1.dtype) + + # Dispatch + expert_x, expert_num_tokens, self.handle, event, hook = \ + self.buffer.low_latency_dispatch(a1, + topk_ids, + self.max_tokens_per_rank, + num_experts, + use_fp8=self.use_fp8_dispatch, + async_finish=False, + return_recv_hook=False) + + expert_x, expert_x_scale = self._do_quant( + expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype, + quant_config.per_act_token_quant, quant_config.block_shape) + + return (expert_x, expert_x_scale, expert_num_tokens, None, None) + + def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool) -> None: + + assert self.handle is not None + + combine_topk_weights = topk_weights + if apply_router_weight_on_input: + # weights have already been applied. + combine_topk_weights = torch.ones_like(topk_weights) + + # TODO (varun) : Enable zero copy mode + _, event, hook = self.buffer.low_latency_combine( + fused_expert_output, + topk_ids, + combine_topk_weights, + self.handle, + async_finish=False, + zero_copy=False, + return_recv_hook=False, + out=output) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py new file mode 100644 index 0000000..0355abb --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -0,0 +1,1021 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Fused batched MoE kernel.""" +from typing import Optional + +import torch +import triton +import triton.language as tl + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.fused_moe import ( + get_config_dtype_str, try_get_optimal_moe_config) +from vllm.model_executor.layers.fused_moe.utils import ( + _resize_cache, moe_kernel_quantize_input, normalize_batched_scales_shape, + normalize_scales_shape) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + group_broadcast) + + +@triton.jit +def moe_mmk( + a_ptrs, + b_ptrs, + K, + expert_id, + a_scale_ptr, + b_scale_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ak: tl.int64, + stride_bk: tl.int64, + stride_ase: tl.int64, + stride_asm: tl.int64, + stride_ask: tl.int64, + stride_bse: tl.int64, + stride_bsk: tl.int64, + stride_bsn: tl.int64, + # Offsets and masks + offs_m, + offs_n, + offs_bn, + mask_m, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # Meta-parameters + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + compute_type: tl.constexpr, + use_w8a8: tl.constexpr, + use_w8a16: tl.constexpr, + per_act_token_quant: tl.constexpr, +): + + offs_k = tl.arange(0, BLOCK_K) + + if use_w8a16: + b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[ + None, :] * stride_bsn + b_scale = tl.load(b_scale_ptrs) + + if use_w8a8: + # block-wise + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + offs_m * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = b_scale_ptr + offs_bsn * stride_bsn + + # per act token + elif per_act_token_quant: + # Load per-token scale for activations + a_scale_ptrs = a_scale_ptr + offs_m * stride_asm + a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:, None] + + b_scale_ptrs = b_scale_ptr + offs_bn[None, :] * stride_bsn + b_scale = tl.load(b_scale_ptrs) + + # tensor-wise + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + a = tl.load(a_ptrs, + mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K), + other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) + # We accumulate along the K dimension. + if use_w8a16: + accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) + elif use_w8a8: + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_K + offs_ks = k_start // group_k + a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask, + mask=mask_m, + other=0.0) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, + None] * b_scale[None, :] + else: + # acc used to enable fp8_fast_accum + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + + if use_w8a16: + accumulator = (accumulator * b_scale).to(compute_type) + elif use_w8a8: + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) + + return accumulator + + +@triton.jit +def expert_triton_kernel( + a_ptr, #[max_tokens, K] + b_ptr, #[K, N] + c_ptr, #[max_tokens, N] + expert_id, + compute_type: tl.constexpr, + # Dimensions + M, + N, + K, + # Quantization data + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # strides + stride_am: tl.int64, + stride_ak: tl.int64, + stride_bk: tl.int64, + stride_bn: tl.int64, + stride_cm: tl.int64, + stride_cn: tl.int64, + stride_ase: tl.int64, + stride_asm: tl.int64, + stride_ask: tl.int64, + stride_bse: tl.int64, + stride_bsk: tl.int64, + stride_bsn: tl.int64, + # offsets + offs_bn, + # Blockwise quantization data + group_n, + group_k, + # Quantization schemes + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + per_act_token_quant: tl.constexpr, + # Kernel config + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) % N + offs_k = tl.arange(0, BLOCK_K) + mask_m = offs_m < M + + # Make grids of a + b pointers + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn + + accumulator = moe_mmk( + a_ptrs, + b_ptrs, + K, + expert_id, + a_scale_ptr, + b_scale_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ak, + stride_bk, + stride_ase, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Offsets and masks + offs_m, + offs_n, + offs_bn, + mask_m, + # Block size for block-wise quantization + group_n, + group_k, + # Meta-parameters + BLOCK_M, + BLOCK_N, + BLOCK_K, + compute_type, + use_fp8_w8a8, + use_int8_w8a16, + per_act_token_quant) + + # store in C + offs_cn = tl.arange(0, BLOCK_N) + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn + c_mask = mask_m[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@triton.jit +def batched_triton_kernel( + a_ptr, # [E, max_num_tokens, K] + b_ptr, # [E, K, N] + c_ptr, # [E, max_num_tokens, N] + expert_num_tokens, # [E] + compute_type: tl.constexpr, + # Dimensions + max_num_tokens, + K, + N, + # Quantization data + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ae: tl.int64, + stride_am: tl.int64, + stride_ak: tl.int64, + stride_be: tl.int64, + stride_bk: tl.int64, + stride_bn: tl.int64, + stride_ce: tl.int64, + stride_cm: tl.int64, + stride_cn: tl.int64, + stride_ase: tl.int64, + stride_asm: tl.int64, + stride_ask: tl.int64, + stride_bse: tl.int64, + stride_bsk: tl.int64, + stride_bsn: tl.int64, + # Blockwise quantization data + group_n: tl.constexpr, + group_k: tl.constexpr, + # Quantization schemes + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + per_act_token_quant: tl.constexpr, + # Kernel config + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + expert_id = tl.program_id(axis=0) + e_num_tokens = tl.load(expert_num_tokens + expert_id) + if e_num_tokens == 0: + # Early exit + return + + # axis 1 is M_blocks * N_blocks + pid_mn = tl.program_id(axis=1) + #num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_m = pid_mn // num_pid_n + pid_n = pid_mn % num_pid_n + + cta_m_start = pid_m * BLOCK_M + cta_n_start = pid_n * BLOCK_N + if cta_m_start >= e_num_tokens: + # Early exit + return + + cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start) + cta_n_size = min(BLOCK_N, N - cta_n_start) + + a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am + b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn + c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm + + cta_n_start * stride_cn) + + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(tl.int64)) % N + + if use_fp8_w8a8: + a_scale_ptr = a_scale_ptr + expert_id * stride_ase + b_scale_ptr = b_scale_ptr + expert_id * stride_bse + + # block-wise + if group_k > 0 and group_n > 0 or per_act_token_quant: + a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm + + expert_triton_kernel( + a_ptr, + b_ptr, + c_ptr, + expert_id, + compute_type, + cta_m_size, # M + cta_n_size, # N + K, # K + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # Strides + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_ase, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # offsets + offs_bn, + # Blockwise quantization data + group_n, + group_k, + # Quantization schemes + use_fp8_w8a8, + use_int8_w8a16, + per_act_token_quant, + # Kernel config + BLOCK_M, + BLOCK_N, + BLOCK_K) + + +def invoke_moe_batched_triton_kernel( + A: torch.Tensor, # [E, max_tokens, K] + B: torch.Tensor, # [E, K, N] + C: torch.Tensor, # [E, max_tokens, N] + expert_num_tokens: torch.Tensor, # [E] + compute_type: tl.dtype, + # Quantization data + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + B_zp: torch.Tensor, + # Quantization schemes + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + config: dict[str, int], + per_act_token_quant: bool, + block_shape: Optional[list[int]] = None): + + assert not use_int4_w4a16 + max_num_tokens = A.size(1) + K = A.size(2) + N = C.size(2) + + BLOCK_M = config['BLOCK_SIZE_M'] + BLOCK_N = config['BLOCK_SIZE_N'] + BLOCK_K = config['BLOCK_SIZE_K'] + + grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) * + triton.cdiv(B.size(1), BLOCK_N)) + + A_scale = normalize_batched_scales_shape(A_scale, + expert_num_tokens.shape[0]) + + if B_scale is not None and B_scale.ndim == 1: + assert B_scale.numel() == expert_num_tokens.shape[0] + B_scale = B_scale.view(-1, 1, 1) + + assert A_scale is None or A_scale.ndim == 3, ( + f"{0 if A_scale is None else A_scale.shape}") + assert B_scale is None or B_scale.ndim == 1 or B_scale.ndim == 3, ( + f"{0 if B_scale is None else B_scale.shape}") + + if B_scale is not None: + if B_scale.ndim == 1: + stride_bse = 1 + stride_bsk = 0 + stride_bsn = 0 + else: + stride_bse = B_scale.stride(0) + stride_bsk = B_scale.stride(2) + stride_bsn = B_scale.stride(1) + + else: + stride_bse = 0 + stride_bsk = 0 + stride_bsn = 0 + + if A_scale is not None: + stride_ase = A_scale.stride(0) + stride_asm = A_scale.stride(1) + stride_ask = A_scale.stride(2) + else: + stride_ase = 0 + stride_asm = 0 + stride_ask = 0 + + batched_triton_kernel[grid]( + A, + B, + C, + expert_num_tokens, + compute_type, + # Dimensions + max_num_tokens, + K, + N, + # Quantization data + A_scale, + B_scale, + B_zp, + # Strides + A.stride(0), + A.stride(1), + A.stride(2), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(0), + C.stride(1), + C.stride(2), + stride_ase, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + # Quantization schemes + use_fp8_w8a8, + use_int8_w8a16, + per_act_token_quant, + # Kernel config + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K) + + +class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): + """ + A reference prepare/finalize class that reorganizes the tokens into + expert batched format, i.e. E x max_num_tokens x K. This is the format + that the PPLX dispatch/combine kernels use. + """ + + def __init__( + self, + max_num_tokens: int, + num_local_experts: int, + num_dispatchers: int, + rank: int, + ): + super().__init__() + self.max_num_tokens = max_num_tokens + self.num_local_experts = num_local_experts + self.rank = rank + self.num_dispatchers_ = num_dispatchers + + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts + + def max_num_tokens_per_rank(self) -> Optional[int]: + return self.max_num_tokens + + def topk_indices_dtype(self) -> Optional[torch.dtype]: + return None + + def num_dispatchers(self) -> int: + return self.num_dispatchers_ + + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor], Optional[torch.Tensor]]: + assert a1.dim() == 2 + assert topk_ids.dim() == 2 + assert topk_ids.size(0) == a1.size(0) + + if apply_router_weight_on_input: + topk = topk_ids.size(1) + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, \ + "apply_router_weight_on_input is only implemented for topk=1" + a1.mul_(topk_weights.to(a1.dtype)) + + num_tokens, hidden_dim = a1.size() + topk = topk_ids.size(1) + + tokens_per_expert = torch.zeros(num_experts, + dtype=torch.int, + device=a1.device) + + num_local_experts = self.num_local_experts + + if quant_config.quant_dtype is None: + b_type = a1.dtype + else: + b_type = quant_config.quant_dtype + + b_a1 = torch.zeros( + (num_local_experts, self.max_num_tokens, hidden_dim), + dtype=b_type, + device=a1.device) + + if quant_config.is_quantized: + scale_shape = quant_config.batched_scale_shape( + num_local_experts, self.max_num_tokens, hidden_dim) + + b_a1_scale = torch.empty(scale_shape, + dtype=torch.float32, + device=a1.device) + else: + assert a1_scale is None + b_a1_scale = None + + first_expert = num_local_experts * self.rank + last_expert = first_expert + num_local_experts + + a1_scale = normalize_scales_shape(a1_scale) + a2_scale = normalize_scales_shape(a2_scale) + + for expert_id in range(first_expert, last_expert): + topks = torch.any(topk_ids == expert_id, dim=1).flatten() + rows = torch.count_nonzero(topks.flatten()) + if rows == 0: + continue + idx = expert_id - first_expert + tokens_per_expert[idx] = rows + rhs = a1[:topks.numel()][topks] + if quant_config.quant_dtype is not None: + if a1_scale is not None: + if quant_config.is_per_act_token: + rhs_a1_scale = a1_scale[:topks.numel()][topks] + else: + rhs_a1_scale = a1_scale + else: + rhs_a1_scale = None + b_a1[idx, :rows, :], b_s = moe_kernel_quantize_input( + rhs, + rhs_a1_scale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + ) + assert b_s is not None + if quant_config.is_per_act_token: + b_a1_scale[idx, :rows] = b_s[:rows] + else: + b_a1_scale[idx, :b_s.shape[0]] = b_s + else: + b_a1[idx, :rows, :] = rhs + + assert b_a1_scale is None or b_a1_scale.ndim == 3 + + return b_a1, b_a1_scale, tokens_per_expert, None, None + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> None: + num_tokens = topk_ids.size(0) + num_local_experts = fused_expert_output.size(0) + K = fused_expert_output.size(-1) + assert output.size(0) == num_tokens and output.size(1) == K + + output.fill_(0) + + first_expert = num_local_experts * self.rank + last_expert = first_expert + num_local_experts + + for expert_id in range(first_expert, last_expert): + matching_tokens = topk_ids == expert_id + topks = torch.any(matching_tokens, dim=1).flatten() + rows = torch.count_nonzero(topks) + rhs = fused_expert_output[expert_id - first_expert, :rows, :] + if not apply_router_weight_on_input: + rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1)) + output[topks] = output[topks] + rhs + + +class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): + """ + A reference MoE expert class that operates on expert batched format, + i.e. E x max_num_tokens x K. This is the format that the pplx + dispatch/combine kernels use. + """ + + def __init__( + self, + max_num_tokens: int, + num_dispatchers: int, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + block_shape: Optional[list[int]] = None, + per_act_token_quant: bool = False, + ): + super().__init__( + FusedMoEQuantConfig.make( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + )) + assert not use_int8_w8a8, "NYI" + assert not use_int8_w8a16, "NYI" + assert not use_int4_w4a16, "NYI" + self.max_num_tokens = max_num_tokens + self.num_dispatchers = num_dispatchers + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) + + def supports_chunking(self) -> bool: + return False + + def supports_expert_map(self) -> bool: + return False + + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + assert a.dim() == 2 + num_dp = self.num_dispatchers + num_experts = local_num_experts + workspace13 = (num_experts, self.max_num_tokens * num_dp, K) + workspace2 = (self.max_num_tokens * num_dp, N) + output = workspace13 + return (workspace13, workspace2, output, a.dtype) + + def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + assert self.quant_config.is_quantized + f32 = torch.float32 + if (self.quant_config.is_per_act_token + or self.quant_config.is_per_tensor): + return t.to(f32) * scale + else: + return t.to(f32) * group_broadcast(scale, t.shape) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ): + assert hidden_states.dim() == 3 + assert expert_num_tokens is not None + + num_local_experts = w1.size(0) + assert num_local_experts == w1.size(0), ( + f"{num_local_experts} == {w1.size(0)}") + + N = w1.size(1) // 2 + + for expert in range(num_local_experts): + # Indexing expert_num_tokens doesn't work w/cudagraphs or inductor + if (torch.compiler.is_compiling() + or torch.cuda.is_current_stream_capturing()): + num = hidden_states.shape[1] + else: + num = int(expert_num_tokens[expert].item()) + + if num == 0: + continue + + tmp = _resize_cache(workspace2, (num, N)) + + if self.quant_config.is_quantized: + assert a1q_scale is not None and w1_scale is not None + input = self.dequant(hidden_states[expert, :, :], + a1q_scale[expert]) + w1_dq = self.dequant(w1[expert], w1_scale[expert]) + input = input[:num] @ w1_dq.transpose(0, 1) + else: + input = hidden_states[expert, :num, :] @ w1[expert].transpose( + 0, 1) + + self.activation(activation, tmp, input.to(tmp.dtype)) + + if self.quant_config.is_quantized: + assert w2_scale is not None + w2_dq = self.dequant(w2[expert], w2_scale[expert]) + else: + w2_dq = w2[expert] + + output[expert, :num, :] = tmp @ w2_dq.transpose(0, 1).to(tmp.dtype) + + +def batched_moe_kernel_quantize_input( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + num_tokens: int, + E: int, + N: int, + expert_num_tokens: torch.Tensor, + qtype: Optional[torch.dtype], + per_act_token_quant: bool, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + if (torch.compiler.is_compiling() + or torch.cuda.is_current_stream_capturing()): + # Note: this does a bunch of extra work because expert_num_tokens is + # ignored but it does support torch.compile + cudagraphs. + hidden_dim = A.size(-1) + assert A_scale is None or A_scale.ndim <= 2, ( + f"{A_scale.shape if A_scale is not None else None}") + A_q, A_q_scale = moe_kernel_quantize_input(A.view(-1, + hidden_dim), A_scale, + qtype, per_act_token_quant, + block_shape) + A_q = A_q.view(E, -1, hidden_dim) + A_q_scale = normalize_batched_scales_shape(A_q_scale, E) + + return A_q, A_q_scale + elif qtype is None: + return A, normalize_batched_scales_shape(A_scale, E) + else: + A_q = torch.empty_like(A, dtype=qtype) + + if per_act_token_quant: + assert block_shape is None + scale_shape = (E, num_tokens, 1) + elif block_shape is not None: + _, block_k = block_shape + k_tiles = (A.shape[-1] + block_k - 1) // block_k + scale_shape = (E, num_tokens, k_tiles) + else: + scale_shape = (E, 1, 1) + + A_q_scale = torch.zeros(scale_shape, + dtype=torch.float32, + device=A.device) + + num_experts = expert_num_tokens.numel() + + A_scale = normalize_batched_scales_shape(A_scale, num_experts) + + for e in range(E): + num_tokens = int(expert_num_tokens[e].item()) + if num_tokens > 0: + if A_scale is not None: + scales = A_scale[e, :min(num_tokens, A_scale.shape[1])] + else: + scales = None + A_q[e, :num_tokens], tmp_scale = moe_kernel_quantize_input( + A[e, :num_tokens], + scales, + qtype, + per_act_token_quant, + block_shape, + ) + assert tmp_scale is not None + A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale + + return A_q, A_q_scale + + +class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): + """ + A Triton based MoE expert class that operates on expert batched format, + i.e. E x max_num_tokens x K. This is the format that the pplx + dispatch/combine kernels use. + """ + + def __init__( + self, + max_num_tokens: int, + num_dispatchers: int, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_act_token_quant: bool = False, + block_shape: Optional[list[int]] = None, + ): + super().__init__( + FusedMoEQuantConfig.make( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + )) + assert not use_int8_w8a8, "NYI" + assert not use_int8_w8a16, "NYI" + assert not use_int4_w4a16, "NYI" + assert max_num_tokens > 0 + assert num_dispatchers > 0 + self.use_fp8_w8a8 = use_fp8_w8a8 + self.use_int8_w8a8 = use_int8_w8a8 + self.use_int4_w4a16 = use_int4_w4a16 + self.use_int8_w8a16 = use_int8_w8a16 + self.max_num_tokens = max_num_tokens + self.num_dispatchers = num_dispatchers + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) + + def supports_chunking(self) -> bool: + return False + + def supports_expert_map(self) -> bool: + return False + + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + assert a.dim() == 2 + num_dp = self.num_dispatchers + num_experts = local_num_experts + max_num_tokens = self.max_num_tokens + workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N)) + workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2)) + output = (num_experts, max_num_tokens * num_dp, K) + return (workspace13, workspace2, output, a.dtype) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ): + # Check constraints. + if self.use_int4_w4a16: + assert hidden_states.size(-1) // 2 == w1.size(2), ( + "Hidden size mismatch") + else: + assert hidden_states.size(-1) == w1.size(2), ( + f"Hidden size mismatch {hidden_states.size(-1)} " + f"!= {w1.size(2)}") + + assert hidden_states.is_contiguous( + ), "Hidden_states must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn + ] + + E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size( + hidden_states, w1, w2, topk_ids) + + assert w1.size(0) == E + assert w2.size(0) == E + + config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + dtype=hidden_states.dtype) + + config = try_get_optimal_moe_config( + w1.size(), + w2.size(), + top_k_num, + config_dtype, + max_num_tokens, + block_shape=self.block_shape, + ) + + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + elif hidden_states.dtype == torch.float8_e4m3fn: + compute_type = tl.bfloat16 + else: + raise ValueError( + f"Unsupported compute_type: {hidden_states.dtype}") + + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + intermediate_cache1 = _resize_cache(workspace13, + (E, max_num_tokens, N)) + intermediate_cache2 = _resize_cache(workspace2, + (E, max_num_tokens, N // 2)) + + if self.use_fp8_w8a8: + intermediate_cache1.fill_(0) + + a1q_scale = normalize_batched_scales_shape(a1q_scale, E) + + # MM1 + invoke_moe_batched_triton_kernel( + A=hidden_states, + B=w1, + C=intermediate_cache1, + expert_num_tokens=expert_num_tokens, + compute_type=compute_type, + A_scale=a1q_scale, + B_scale=w1_scale, + B_zp=w1_zp, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + config=config, + per_act_token_quant=self.per_act_token_quant, + block_shape=self.block_shape) + + intermediate_cache2.fill_(0) + + # TODO (bnell): use triton utility from batched deep gemm. + self.activation(activation, intermediate_cache2.view(-1, N // 2), + intermediate_cache1.view(-1, N)) + + qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input( + intermediate_cache2, a2_scale, max_num_tokens, E, N, + expert_num_tokens, self.quant_dtype, self.per_act_token_quant, + self.block_shape) + + invoke_moe_batched_triton_kernel( + A=qintermediate_cache2, + B=w2, + C=output, + expert_num_tokens=expert_num_tokens, + compute_type=compute_type, + A_scale=a2q_scale, + B_scale=w2_scale, + B_zp=w2_zp, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + config=config, + per_act_token_quant=self.per_act_token_quant, + block_shape=self.block_shape) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py new file mode 100644 index 0000000..51e0e77 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -0,0 +1,260 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Fused MoE utilities for GPTQ.""" +import functools +from typing import Optional + +import torch +try: + import lightop +except Exception: + print("INFO: Please install lightop if you want to infer awq of marlin.\n") + +import vllm.envs as envs +import vllm._custom_ops as ops +from vllm.model_executor.layers.fused_moe.fused_moe import ( + moe_align_block_size, try_get_optimal_moe_config) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_make_workspace_new, maybe_warn_marlin_atomic_add) +from vllm.scalar_type import ScalarType, scalar_types +from vllm.utils import direct_register_custom_op +from vllm.model_executor.layers.fused_moe.fused_moe import get_moe_cache +def get_scalar_type(num_bits: int, has_zp: bool): + if has_zp: + return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 + else: + return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 + +def fused_marlin_moe( + hidden_states: torch.Tensor, # 32, 7168 + w1: torch.Tensor, # 256, 512, 7168 --> 32*8, 512 --> 32*8, 256 + w2: torch.Tensor, # 256, 256, 7168 + w1_scale_zero: torch.Tensor, + w2_scale_zero: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + num_bits: int = 4, + is_k_full: bool = True, + inplace: bool = False) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - w1_scale (torch.Tensor): Scale to be used for w1. + - w2_scale (torch.Tensor): Scale to be used for w2. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices. + - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices. + - sort_indices1 (Optional[torch.Tensor]): The first act_order input + permutation. + - sort_indices2 (Optional[torch.Tensor]): The second act_order input + permutation. + - topk_weights (torch.Tensor): Top-k weights. + - topk_ids (torch.Tensor): Indices of topk-k elements. + - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. + - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. + - num_bits (bool): The number of bits in expert weights quantization. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # quant_type = ScalarType.from_id(quant_type_id) + # assert quant_type in [ + # scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8, + # scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f + # ] + + # bit4_scalar_types = [ + # scalar_types.uint4, scalar_types.uint4b8, scalar_types.float4_e2m1f + # ] + # num_bits = 4 if quant_type in bit4_scalar_types else 8 + + # Check constraints. + assert hidden_states.shape[0] == gating_output.shape[ + 0], "Number of tokens mismatch" + assert hidden_states.shape[ + 1] == w1.shape[1] * 16, "Hidden size mismatch w1" + assert hidden_states.shape[1] == w2.shape[2] // ( + num_bits // 2), "Hidden size mismatch w2" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [torch.float16, torch.bfloat16] + # assert num_bits in [4] + assert num_bits in [4] + + num_tokens, K = hidden_states.shape # 32, 7168 + E = w1.shape[0] # 256 + N = w2.shape[1] * 16 # 256 + topk = topk_ids.shape[1] # 8 + + #暂时固定为16384 + #CHUNK_SIZE = 16384 + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + M = min(num_tokens, CHUNK_SIZE) + + if workspace is None: + sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count + workspace = torch.zeros(sms * 3, + dtype=torch.int, + device=hidden_states.device, + requires_grad=False) + + scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None) + scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None) + if global_num_experts == -1: + global_num_experts = E + intermediate_cache2 = torch.empty( + (M * topk, N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + if envs.VLLM_USE_GLOBAL_CACHE13: + intermediate_cache13 = get_moe_cache(topk, N, K, device=hidden_states.device, dtype=hidden_states.dtype) + else: + intermediate_cache13 = torch.empty( + (M * topk * max(2 * N, K), ), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache1 = intermediate_cache13[:M * topk * 2 * N] + intermediate_cache1 = intermediate_cache1.view(-1, 2 * N) + intermediate_cache3 = intermediate_cache13[:M * topk * K] + intermediate_cache3 = intermediate_cache3.view(-1, K) + + use_atomic_add = hidden_states.dtype == torch.half or \ + torch.cuda.get_device_capability(hidden_states.device)[0] >= 9 + + if inplace: + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) + + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + + begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, + num_tokens)) + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + tokens_in_chunk, _ = curr_hidden_states.size() + + if tokens_in_chunk == 0: + break + intermediate_cache3 = intermediate_cache3.view(-1, K) + if tokens_in_chunk < CHUNK_SIZE and chunk > 0: + intermediate_cache1 = intermediate_cache1[:tokens_in_chunk * topk, :] + intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * topk, :] + intermediate_cache3 = intermediate_cache3[:tokens_in_chunk * topk, :] + M = tokens_in_chunk + + # Select block_size_m + for block_size_m in [16, 32, 48, 64, 80]: + if M * topk / E / block_size_m < 0.9: + break + + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(curr_topk_ids, block_size_m, global_num_experts, expert_map) + + intermediate_cache1 = lightop.moe_marlin_w4a16( + curr_hidden_states, + intermediate_cache1, + w1, + w1_scale_zero, + g_idx1, + sort_indices1, + workspace, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + curr_topk_weights, + block_size_m, + topk, + False, + expert_map is not None, + M, + 2 * N, + K, + is_k_full, + use_atomic_add, + True, + False + ) + + torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1) + + intermediate_cache3 = lightop.moe_marlin_w4a16( + intermediate_cache2, + intermediate_cache3, + w2, + w2_scale_zero, + g_idx2, + sort_indices2, + workspace, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + curr_topk_weights, + block_size_m, + 1, + True, + expert_map is not None, + M * topk, + K, + N, + is_k_full, + use_atomic_add, + True, + False + ).view(-1, topk, K) + + ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx]) + + return out_hidden_states + +def fused_marlin_moe_fake( + hidden_states: torch.Tensor, # 32, 7168 + w1: torch.Tensor, # 256, 512, 7168 --> 32*8, 512 --> 32*8, 256 + w2: torch.Tensor, # 256, 256, 7168 + w1_scale_zero: torch.Tensor, + w2_scale_zero: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + num_bits: int = 4, + is_k_full: bool = True, + inplace: bool = False) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +direct_register_custom_op( + op_name="fused_marlin_moe", + op_func=fused_marlin_moe, + mutates_args=[], + fake_impl=fused_marlin_moe_fake, +) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py new file mode 100644 index 0000000..57a7147 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -0,0 +1,2084 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Fused MoE kernel.""" +import functools +import json +import os +import math +from typing import Any, Callable, Dict, Optional, List, Optional, Tuple + +import torch + +import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import _custom_ops as ops +from vllm.logger import init_logger +# yapf: disable +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEQuantConfig, get_config_quant_dtype) +from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + _valid_cutlass_block_scaled_grouped_gemm, + run_cutlass_block_scaled_fused_experts) +# yapf: enable +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + _valid_deep_gemm, deep_gemm_moe_fp8) +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size) + +try: + from lmslim.layers.gemm.int8_utils import ( + per_token_group_quant_int8, per_token_quant_int8) + from lmslim.layers.fused_moe.fuse_moe_int8 import (fused_experts_impl_int8, get_w8a8moe_json) + from lmslim.layers.fused_moe.fuse_moe_w4a8 import fused_experts_impl_w4a8 +except Exception: + print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n") + +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP) +from vllm.model_executor.layers.fused_moe.utils import ( + _resize_cache, moe_kernel_quantize_input) +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton +from vllm.utils import direct_register_custom_op + +# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled + +logger = init_logger(__name__) +if envs.VLLM_USE_GLOBAL_CACHE13: + moe_cache_singleton = None +def get_moe_cache(top_k_num,N,K,device,dtype): + global moe_cache_singleton + if moe_cache_singleton is None: + moe_cache_singleton = torch.empty(envs.VLLM_FUSED_MOE_CHUNK_SIZE * top_k_num *max(N, K), device=device, dtype=dtype) + logger.info(f"Initializing moe_cache_singleton shape: {moe_cache_singleton.shape}, memory: {moe_cache_singleton.element_size() * moe_cache_singleton.numel() / 1024**2:.2f} MB") + return moe_cache_singleton + +@triton.jit +def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, + token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N, + compute_type): + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + +@triton.jit +def fused_moe_kernel_awq( + # Pointers to matrices + a_ptr, # [4, 7168] + b_ptr, # [256, 512, 3584] + c_ptr, # (8, 8, 512) + b_scale_ptr, # (256, 512, 56) + b_zp_ptr, # (256, 256, 56) + topk_weights_ptr, + sorted_token_ids_ptr, # [0, 1, 2, 3, 4] + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, # pading后的总索引长度 + num_valid_tokens, # 有效索引的上限 + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, #1 + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk,#1 + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, # 128 + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) # [block_m] + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + + tl.arange(0, BLOCK_SIZE_N)) % N # [block_n] + offs_k = tl.arange(0, BLOCK_SIZE_K) # 0, 1, 2, ...... , 127 # # [block_k] + offs_k2 = tl.arange(0, BLOCK_SIZE_K // 2) # 0, 1, 2, ...... , 127 # # [block_k] + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) # [block_m, block_k] + + off_experts = tl.load(expert_ids_ptr + pid_m) + + if use_int4_w4a16: + # [0, 1, 2, ...... , 126, 127] --> [0, 0, 1, 1 ...... , 63, 63] + # [128, 129, 130, ...... , 254, 255] --> [64, 64, 65, 65 ...... , 127, 127] + + # b_ptrs = b_ptr + off_experts * stride_be + \ + # (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn + b_ptrs = b_ptr + off_experts * stride_be + \ + offs_bn[:, None] * stride_bn + (offs_k2[None, :]) * stride_bk + # tl.device_print("stride_bn",stride_bsn)>1 + # tl.device_print("stride_bk",stride_bk)=1 + b_shifter = (offs_k[:, None] % 2) * 4 # 0, 4 + elif use_int8_w8a16: + b_ptrs = b_ptr + off_experts * stride_be + \ + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + + if not has_zp and use_int4_w4a16: + b_zp_num = 8 + if not has_zp and use_int8_w8a16: + b_zp_num = 128 + elif has_zp and use_int4_w4a16: + b_zp_shifter = (offs_bn[None, :] % 2) * 4 # 0, 4 + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if not block_k_diviable: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load(a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) + b = tl.load(b_ptrs) + if use_int4_w4a16: + b = tl.interleave(b, b) + b= b.trans() + b = (b >> b_shifter) & 0xF + + b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \ + offs_bn[None, :] * stride_bsk + \ + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsn + qzeros_scles = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) + + scales_int16 = tl.cast(qzeros_scles,tl.uint16) + b_scale = tl.cast(scales_int16,tl.float16,bitcast=True) + # tl.device_print("b_scale dequant",b_scale) + + mid = qzeros_scles >> 16 + # b_zp = tl.cast(mid,tl.float16,bitcast=False) + b_zp = tl.cast(mid,tl.float16) + # b_zp = tl.cast(zeros_int16,tl.float16,bitcast=False) + + # tl.device_print("bzp",b_zp) + + # We accumulate along the K dimension. + b = ((b - b_zp) * b_scale).to(tl.float16) + accumulator = tl.dot(a, b, acc=accumulator) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + if use_int4_w4a16: + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + else: + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@triton.jit +def fused_moe_kernel_gptq_awq( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to( + tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + + off_experts = tl.load(expert_ids_ptr + pid_m) + + if use_int4_w4a16: + b_ptrs = b_ptr + off_experts * stride_be + \ + (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn + b_shifter = (offs_k[:, None] % 2) * 4 + elif use_int8_w8a16: + b_ptrs = b_ptr + off_experts * stride_be + \ + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + + if not has_zp and use_int4_w4a16: + b_zp_num = 8 + if not has_zp and use_int8_w8a16: + b_zp_num = 128 + elif has_zp and use_int4_w4a16: + b_zp_shifter = (offs_bn[None, :] % 2) * 4 + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + + if not block_k_diviable: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load(a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) + b = tl.load(b_ptrs) + if use_int4_w4a16: + b = (b >> b_shifter) & 0xF + + b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \ + offs_bn[None, :] * stride_bsn + \ + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) + b_scale = b_scale.to(tl.float32) + + if has_zp and use_int4_w4a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ + (offs_bn[None, :] // 2) * stride_bzn + \ + offs_k_true * stride_bzk + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = ((b_zp >> b_zp_shifter) & 0xF) + b_zp = b_zp.to(tl.float32) + elif has_zp and use_int8_w8a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ + offs_bn[None, :] * stride_bzn + \ + offs_k_true * stride_bzk + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = b_zp.to(tl.float32) + + # We accumulate along the K dimension. + if has_zp: + b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + else: + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) + accumulator = tl.dot(a, b, acc=accumulator) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + if use_int4_w4a16: + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + else: + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@triton.jit +def fused_moe_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + a_scale_ptr, + b_scale_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + use_fp8_w8a8: tl.constexpr, + use_int8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + per_channel_quant: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + # num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + # num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + # num_pid_in_group = GROUP_SIZE_M * num_pid_n + # group_id = pid // num_pid_in_group + # first_pid_m = group_id * GROUP_SIZE_M + # group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + # pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + # pid_n = (pid % num_pid_in_group) // group_size_m + if GROUP_SIZE_M ==1: + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + else: + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + off_experts = tl.load(expert_ids_ptr + pid_m) + if off_experts == -1: + # ----------------------------------------------------------- + # Write back zeros to the output when the expert is not + # in the current expert parallel rank. + write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, + offs_token, token_mask, BLOCK_SIZE_M, + BLOCK_SIZE_N, compute_type) + return + + offs_bn = (pid_n * BLOCK_SIZE_N + + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn) + if use_int8_w8a16: + b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[ + None, :] * stride_bsn + b_scale = tl.load(b_scale_ptrs) + + if use_fp8_w8a8 or use_int8_w8a8: + # block-wise + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse + + offs_bsn * stride_bsn) + # channel-wise + elif per_channel_quant: + b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[ + None, :] * stride_bsn + b_scale = tl.load(b_scale_ptrs) + # Load per-token scale for activations + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, + None] + # tensor-wise + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + a = tl.load(a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + # We accumulate along the K dimension. + if use_int8_w8a16: + accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) + elif use_fp8_w8a8 or use_int8_w8a8: + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask, + mask=token_mask, + other=0.0) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, + None] * b_scale[None, :] + else: + if use_fp8_w8a8: + # acc used to enable fp8_fast_accum + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) + else: + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + if use_int8_w8a16: + accumulator = (accumulator * b_scale).to(compute_type) + elif use_fp8_w8a8 or use_int8_w8a8: + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def invoke_fused_moe_kernel(A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], + topk_weights: Optional[torch.Tensor], + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + top_k: int, + config: dict[str, Any], + compute_type: tl.dtype, + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + use_int4_w4a8: bool, + per_channel_quant: bool, + block_shape: Optional[list[int]] = None, + use_nn_moe: Optional[bool]=False) -> None: + assert topk_weights is not None or not mul_routed_weight + assert topk_weights is None or topk_weights.stride(1) == 1 + assert sorted_token_ids.stride(0) == 1 + + if use_fp8_w8a8 or use_int8_w8a8: + assert B_scale is not None + assert (block_shape is None + or triton.cdiv(B.size(-2), block_shape[0]) == B_scale.size(-2)) + assert (block_shape is None + or triton.cdiv(B.size(-1), block_shape[1]) == B_scale.size(-1)) + + elif use_int8_w8a16 or use_int4_w4a16: + assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 + else: + assert A_scale is None + assert B_scale is None + + M = A.size(0) + num_tokens = M * top_k + + EM = sorted_token_ids.size(0) + if A.size(0) < config["BLOCK_SIZE_M"]: + # optimize for small batch_size. + # We assume that top_ids of each token is unique, so + # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, + # and we can skip some invalid blocks. + EM = min(sorted_token_ids.size(0), + A.size(0) * top_k * config['BLOCK_SIZE_M']) + grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( + B.size(1) if not use_nn_moe else B.size(2), META['BLOCK_SIZE_N']), ) + + if (use_int8_w8a16 or use_int4_w4a16) and \ + block_shape is not None and block_shape[1] > 0: + assert B_scale is not None and B_scale.ndim == 3 + assert B_zp is None or B_zp.ndim == 3 + if os.environ.get('moe_wna16_use_cuda') == '1': + use_moe_wna16_cuda = should_moe_wna16_use_cuda( + num_valid_tokens=num_tokens, + group_size=block_shape[1], + num_experts=B.size(0), + bit=4 if use_int4_w4a16 else 8) + + config = config.copy() + config.update( + get_moe_wna16_block_config(config=config, + use_moe_wna16_cuda=use_moe_wna16_cuda, + num_valid_tokens=num_tokens, + size_k=A.size(1), + size_n=B.size(1), + num_experts=B.size(1), + group_size=block_shape[1], + real_top_k=top_k, + block_size_m=config["BLOCK_SIZE_M"])) + + if use_moe_wna16_cuda: + bit = 4 if use_int4_w4a16 else 8 + ops.moe_wna16_gemm(A, C, B, B_scale, B_zp, + topk_weights if mul_routed_weight else None, + sorted_token_ids, expert_ids, + num_tokens_post_padded, top_k, + config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_K"], bit) + return + + if os.environ.get('AWQ_MOE_SZ') == '1': + fused_moe_kernel_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.size(1), + A.size(1), + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + **config, + ) + else: + fused_moe_kernel_gptq_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.size(1), + A.size(1), + EM, + num_tokens, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + **config, + ) + else: + # config = config.copy() + # BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K") + # if block_shape is not None: + # BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], + # block_shape[1])) + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.size(1) if not use_nn_moe else B.size(2), + A.size(1), + EM, + num_tokens, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2) if not use_nn_moe else B.stride(1), + B.stride(1) if not use_nn_moe else B.stride(2), + C.stride(1), + C.stride(2), + A_scale.stride(0) + if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) + if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) + if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) + if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) + if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + per_channel_quant=per_channel_quant, + **config, + ) + + +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 +def get_config_file_name(E: int, + N: int, + dtype: Optional[str], + block_shape: Optional[List[int]] = None, use_nn_moe: Optional[bool] = False) -> str: + device_name = current_platform.get_device_name().replace(" ", "_") + dtype_selector = "" if not dtype else f",dtype={dtype}" + block_shape_selector = ("" if not block_shape or not all(block_shape) else + f",block_shape={block_shape}").replace(" ", "") + if not use_nn_moe: + return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 + else: + return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}_nn.json" + +# Adapted from: https://github.com/sgl-project/sglang/pull/2628 +@functools.lru_cache +def get_moe_configs( + E: int, + N: int, + dtype: Optional[str], + block_n: Optional[int] = None, + block_k: Optional[int] = None, + use_nn_moe: Optional[bool] = False, +) -> Optional[Dict[int, Any]]: + """ + Return optimized configurations for the fused MoE kernel. + + The return value will be a dictionary that maps an irregular grid of + batch sizes to configurations of the fused_moe kernel. To evaluate the + kernel on a given batch size bs, the closest batch size in the grid should + be picked and the associated configuration chosen to invoke the kernel. + """ + + # First look up if an optimized configuration is available in the configs + # directory + block_shape = [block_n, block_k] if block_n and block_k else None + json_file_name = get_config_file_name(E, N, dtype, block_shape, use_nn_moe=use_nn_moe) + + config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + if torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120: + config_file_path_120 = config_file_path.replace(".json","_120.json") + if os.path.exists(config_file_path_120): + with open(config_file_path_120) as f: + logger.info("Using configuration from %s for MoE layer.", + config_file_path_120) + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + + if os.path.exists(config_file_path): + with open(config_file_path) as f: + logger.info("Using configuration from %s for MoE layer.", + config_file_path) + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + + # If no optimized configuration is available, we will use the default + # configuration + logger.warning( + ("Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s"), config_file_path) + return None + + +def get_moe_wna16_block_config(config: dict[str, + int], use_moe_wna16_cuda: bool, + num_valid_tokens: int, size_k: int, size_n: int, + num_experts: int, group_size: int, + real_top_k: int, block_size_m: int): + if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config: + # optimal block config is set + return {} + if not use_moe_wna16_cuda: + # triton moe wna16 kernel + if num_valid_tokens // real_top_k == 1: + # if bs=1, use a smaller BLOCK_SIZE_N + return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64} + else: + return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32} + else: + # cuda moe wna16 kernel + # set default block_size 128, and increase them when num_blocks + # is too large. + block_size_n = 128 + block_size_k = 128 + if block_size_k <= group_size: + block_size_k = group_size + + num_n_blocks = size_k // block_size_k + num_k_blocks = size_n // block_size_k + num_m_blocks = (num_valid_tokens + block_size_m - 1) / block_size_m + \ + num_experts + if num_valid_tokens // real_top_k <= block_size_m: + num_m_blocks = min(num_m_blocks, num_valid_tokens) + num_blocks = num_m_blocks * num_n_blocks * num_k_blocks + + if size_k % 256 == 0 and num_blocks >= 256 and \ + block_size_k < 256: + block_size_k = 256 + num_blocks = num_blocks // (256 // block_size_k) + + if num_m_blocks <= 16 and size_k % (block_size_k * 2) == 0 and \ + size_k % (block_size_k * 2) == 0 and block_size_k <= 512 and \ + num_blocks >= 512: + block_size_k = block_size_k * 2 + num_blocks = num_blocks // 2 + + if num_blocks > 1024: + block_size_n = 256 + num_n_blocks = num_n_blocks // 2 + num_blocks = num_blocks // 2 + + if size_n <= 1024 and num_blocks >= 1024: + # The kernel performance got much better with BLOCK_SIZE_N=1024 + # when num_blocks is large, event when N is small. + # Not sure why, maybe it force the CUDA SM process only one block + # at the same time. + block_size_n = 1024 + + return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k} + + +def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int, + num_experts: int, bit: int): + return bit == 4 and group_size in [32, 64, 128] and \ + num_valid_tokens / num_experts <= 6 + + +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], + is_marlin: bool, + block_shape: Optional[List[int]] = None, + use_nn_moe: Optional[bool]=False, +) -> dict[str, int]: + if dtype == "fp8_w8a8" and block_shape is not None: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] + # BLOCK_SIZE_K must be divisible by block_shape[1] + # num_stages=3 can cause triton.runtime.errors.OutOfResources + # on ROCm, set it to 2 instead. + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_shape[0], + "BLOCK_SIZE_K": block_shape[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 if not current_platform.is_rocm() else 2, + } + + # elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None: + # # moe wna16 kernels + # # only set BLOCK_SIZE_M + # # BLOCK_SIZE_N and BLOCK_SIZE_K would be set later + # bit = 4 if dtype == "int4_w4a16" else 8 + # use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk, + # block_shape[1], E, bit) + # if use_moe_wna16_cuda: + # config = {"BLOCK_SIZE_M": min(16, M)} + # elif M <= 20: + # config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1} + # elif M <= 40: + # config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1} + # else: + # config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1} + elif is_marlin: + for block_size_m in [8, 16, 32, 48, 64]: + if M * topk / E / block_size_m < 0.9: + break + return {"BLOCK_SIZE_M": block_size_m} + elif M <= E: + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } + else: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + + if use_nn_moe: + config["num_ldmatrixes"] = 1 + return config + + +def try_get_optimal_moe_config( + w1_shape: tuple[int, ...], + w2_shape: tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + is_marlin: bool = False, + block_shape: Optional[List[int]] = None, + use_nn_moe: Optional[bool] = False, +) -> dict[str, int]: + from vllm.model_executor.layers.fused_moe import get_config + override_config = get_config() + if override_config: + config = override_config + else: + # First try to load optimal config from the file + if not use_nn_moe: + E, _, N = w2_shape + else: + E, N, _ = w2_shape + # if dtype == "int4_w4a16": + # N = N * 2 + block_n = block_shape[0] if block_shape else 0 + block_k = block_shape[1] if block_shape else 0 + configs = get_moe_configs(E, N, dtype, block_n, block_k, use_nn_moe=use_nn_moe) + + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = get_default_config(M, E, N, w1_shape[2] if not use_nn_moe else w1_shape[1], top_k, dtype, + is_marlin, block_shape, use_nn_moe=use_nn_moe) + return config + + +def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool) -> tuple[torch.Tensor, ...]: + ops.topk_softmax( + topk_weights, + topk_indices, + token_expert_indices, + gating_output, + ) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_indices + + +def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]: + # if is_rocm_aiter_moe_enabled(): + # from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax + # return rocm_aiter_topk_softmax + return vllm_topk_softmax + + +def fused_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + indices_type: Optional[torch.dtype] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert hidden_states.size(0) == gating_output.size(0), ( + "Number of tokens mismatch") + + M, _ = hidden_states.size() + + topk_weights = torch.empty(M, + topk, + dtype=torch.float32, + device=hidden_states.device) + topk_ids = torch.empty( + M, + topk, + dtype=torch.int32 if indices_type is None else indices_type, + device=hidden_states.device) + token_expert_indices = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + + gating_output_float = gating_output.float() # TODO(woosuk): Optimize this. + + topk_func = dispatch_topk_func() + topk_weights, topk_ids = topk_func(topk_weights, topk_ids, + token_expert_indices, + gating_output_float, renormalize) + + return topk_weights, topk_ids, token_expert_indices + + +def is_power_of_two(n): + return n > 0 and math.log2(n).is_integer() + + +# This is used by the Deepseek-V2 and Deepseek-V3 model +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: + + assert hidden_states.size(0) == gating_output.size(0), ( + "Number of tokens mismatch") + + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + num_token = scores.size(0) + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = (scores.view(num_token, num_expert_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) + else: + group_scores = scores.view(num_token, num_expert_group, + -1).max(dim=-1).values # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, + sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.size(-1) // num_expert_group).reshape(num_token, -1) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), + float("-inf")) # [n, e] + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, + k=topk, + dim=-1, + sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +def get_config_dtype_str( + dtype: torch.dtype, + use_int4_w4a16: Optional[bool] = False, + use_int8_w8a16: Optional[bool] = False, + use_fp8_w8a8: Optional[bool] = False, + use_int8_w8a8: Optional[bool] = False, + use_int4_w4a8: Optional[bool] = False) -> Optional[str]: + if use_fp8_w8a8: + return "fp8_w8a8" + elif use_int8_w8a8: + return "int8_w8a8" + elif use_int8_w8a16: + return "int8_w8a16" + elif use_int4_w4a16: + return "int4_w4a16" + elif use_int4_w4a8: + return "int4_w4a8" + elif dtype == torch.float: + # avoiding cases where kernel fails when float32 MoE + # use fp16/bfloat16 configs + return "float32" + return None + + +def inplace_fused_experts(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: Optional[str] = None, + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_int4_w4a8: bool =False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, + use_nn_moe: Optional[bool] = False) -> None: + fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, + activation, apply_router_weight_on_input, use_fp8_w8a8, + use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8, + per_channel_quant, global_num_experts, expert_map, + w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, + block_shape, use_nn_moe) + + +def inplace_fused_experts_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: Optional[str] = None, + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_int4_w4a8: bool =False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, + use_nn_moe: Optional[bool] = False) -> None: + pass + + +direct_register_custom_op( + op_name="inplace_fused_experts", + op_func=inplace_fused_experts, + mutates_args=["hidden_states"], + fake_impl=inplace_fused_experts_fake, + tags=(torch.Tag.needs_fixed_stride_order, ), +) + + +def outplace_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: Optional[str] = None, + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_int4_w4a8: bool =False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, + use_nn_moe: Optional[bool] = False) -> torch.Tensor: + return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, + False, activation, apply_router_weight_on_input, + use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, + use_int4_w4a16,use_int4_w4a8, per_channel_quant, + global_num_experts, expert_map, w1_scale, + w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, + block_shape, use_nn_moe) + + +def outplace_fused_experts_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: Optional[str] = None, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_int4_w4a8: bool =False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, + use_nn_moe: Optional[bool] = False) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +direct_register_custom_op( + op_name="outplace_fused_experts", + op_func=outplace_fused_experts, + mutates_args=[], + fake_impl=outplace_fused_experts_fake, + tags=(torch.Tag.needs_fixed_stride_order, ), +) + + +def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor: + torch.ops.vllm.inplace_fused_experts(**kwargs) + hidden_states = kwargs['hidden_states'] + return hidden_states + + +def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor: + return torch.ops.vllm.outplace_fused_experts(**kwargs) + + +def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: + if inplace: + return torch_vllm_inplace_fused_experts + return torch_vllm_outplace_fused_experts + + +# TODO (bnell): replace this with modular op. Can get rid of inplace/outplace +# torch ops. +def fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_int4_w4a8: bool =False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, + allow_deep_gemm: bool = False, + allow_cutlass_block_scaled_grouped_gemm: bool = False, + use_nn_moe: Optional[bool] = False) -> torch.Tensor: + # For now, disable DeepGemm for small N (<= 512) until better + # permute/unpermute ops are available. + N = w1.size(1) + if (allow_deep_gemm and use_fp8_w8a8 and N > 512 + and _valid_deep_gemm(hidden_states, w1, w2)): + assert apply_router_weight_on_input is False + return deep_gemm_moe_fp8( + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=inplace, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8 + and _valid_cutlass_block_scaled_grouped_gemm(hidden_states, w1, w2)): + assert apply_router_weight_on_input is False + return run_cutlass_block_scaled_fused_experts( + a=hidden_states, + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + topk_weights=topk_weights, + topk_ids=topk_ids) + else: + return dispatch_fused_experts_func(inplace)( + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + use_int4_w4a8=use_int4_w4a8, + per_channel_quant=per_channel_quant, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape, + use_nn_moe=use_nn_moe) + + +def fused_experts_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_int4_w4a8: bool =False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, + use_nn_moe: Optional[bool] = False, +) -> torch.Tensor: + num_tokens = hidden_states.size(0) + if use_nn_moe: + E, _, N = w1.size() + else: + E, N, _ = w1.size() + K = w2.size(1) + + if global_num_experts == -1: + global_num_experts = E + top_k_num = topk_ids.size(1) + # We execute the fused_moe kernel in chunks to circumvent this issue: + # https://github.com/vllm-project/vllm/issues/5938 + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + M = min(num_tokens, CHUNK_SIZE) + + if envs.VLLM_USE_GLOBAL_CACHE13: + cache13 = get_moe_cache(top_k_num, N,K if not use_nn_moe else w2.shape[2], device=hidden_states.device, dtype=hidden_states.dtype) + else: + cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]), device=hidden_states.device, dtype=hidden_states.dtype) + if use_int8_w8a8 is True: + return fused_experts_impl_int8(hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + cache13 = cache13, + inplace=inplace, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=False, + use_int8_w8a8=True, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_channel_quant=per_channel_quant, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape, + use_nn_moe=False + ) + elif use_int4_w4a8 is True: + return fused_experts_impl_w4a8(hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=inplace, + cache13 = cache13, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8= False, + use_int8_w8a8= False, + use_int8_w8a16= False, + use_int4_w4a16 = False, + use_int4_w4a8 = True, + per_channel_quant=per_channel_quant, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape, + use_nn_moe= False + ) + + # + if use_int4_w4a16: + assert hidden_states.size(1) // 2 == w1.size(2), ( + "Hidden size mismatch") + elif use_nn_moe: + assert hidden_states.size(1) == w1.size(1), "Hidden size mismatch" + else: + assert hidden_states.size(1) == w1.size(2), ( + f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}") + + assert topk_weights.size() == topk_ids.size(), "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + + config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + use_int4_w4a8=use_int4_w4a8, + dtype=hidden_states.dtype) + + qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + use_int4_w4a8=use_int4_w4a8) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.size(), + w2.size(), + top_k_num, + config_dtype, + block_shape=block_shape, + use_nn_moe=use_nn_moe, + ) + + config = get_config_func(M) + + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N) + intermediate_cache3 = cache13[:M * top_k_num * (K if not use_nn_moe else w2.shape[2])].view(M, top_k_num, K if not use_nn_moe else w2.shape[2]) + + # This needs separate memory since it's used concurrently with cache1 + intermediate_cache2 = torch.empty((M * top_k_num, N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") + + if inplace: + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) + + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, + num_tokens)) + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + tokens_in_chunk, _ = curr_hidden_states.size() + + if tokens_in_chunk == 0: + break + + if tokens_in_chunk < CHUNK_SIZE and chunk > 0: + # Adjust the intermediate cache size and config for the last + # chunk. Note that in most cases we only have one chunk + # so the cache size and config are already set correctly and + # do not need to be adjusted. + intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] + intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * + topk_ids.size(1)] + intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] + if not use_int8_w8a8: + config = get_config_func(tokens_in_chunk) + + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + + qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input( + A=curr_hidden_states, + A_scale=a1_scale, + quant_dtype=qtype, + per_act_token_quant=per_channel_quant, + block_shape=block_shape) + + if use_int4_w4a16: + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], + global_num_experts, expert_map, curr_hidden_states.shape[0])) + else: + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], + global_num_experts, expert_map)) + + invoke_fused_moe_kernel(qcurr_hidden_states, + w1, + intermediate_cache1, + a1q_scale, + w1_scale, + w1_zp, + curr_topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + apply_router_weight_on_input, + top_k_num, + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + use_int4_w4a8=use_int4_w4a8, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + use_nn_moe=use_nn_moe) + + if activation == "silu": + torch.ops._C.silu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, N)) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, N)) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + + qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( + A=intermediate_cache2, + A_scale=a2_scale, + quant_dtype=qtype, + per_act_token_quant=per_channel_quant, + block_shape=block_shape) + + invoke_fused_moe_kernel(qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + w2_scale, + w2_zp, + curr_topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + not apply_router_weight_on_input, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + use_int4_w4a8=use_int4_w4a8, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + use_nn_moe=use_nn_moe) + + ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), + out_hidden_states[begin_chunk_idx:end_chunk_idx]) + + return out_hidden_states + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + inplace: bool = False, + activation: str = "silu", + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_int4_w4a8: bool =False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, + use_nn_moe: Optional[bool] = False, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - activation (str): The activation function to apply after the first + MoE layer. + - num_expert_group: Optional[int]: additional parameter for grouped_topk + - topk_group: Optional[int]: additional parameter for grouped_topk + - use_grouped_topk: If True, use grouped_topk instead of fused_topk + note: Deepseekv2 model uses grouped_topk + - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[List[int]]): Optional block size for block-wise + quantization. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + + if use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, + topk, renormalize, + num_expert_group, topk_group) + elif custom_routing_function is None: + topk_weights, topk_ids, token_expert_indices = fused_topk( + hidden_states, gating_output, topk, renormalize) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states, gating_output, topk, renormalize) + + return fused_experts(hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace=inplace, + activation=activation, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + use_int4_w4a8=use_int4_w4a8, + per_channel_quant=per_channel_quant, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape, + use_nn_moe=use_nn_moe) + + +class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + use_int4_w4a8: bool = False, + per_act_token_quant: bool = False, + block_shape: Optional[List[int]] = None, + ): + super().__init__( + FusedMoEQuantConfig.make( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + use_int4_w4a8=use_int4_w4a8, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + )) + + self.use_fp8_w8a8 = use_fp8_w8a8 + self.use_int4_w4a16 = use_int4_w4a16 + self.use_int8_w8a8 = use_int8_w8a8 + self.use_int8_w8a16 = use_int8_w8a16 + self.use_int4_w4a8 = use_int4_w4a8 + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) + + def supports_chunking(self) -> bool: + return True + + def supports_expert_map(self) -> bool: + return True + + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + workspace1 = (M, topk, max(N * 2, K)) + workspace2 = (M, topk, N) + output = (M, topk, K) + return (workspace1, workspace2, output, a.dtype) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ): + # Check constraints. + if self.use_int4_w4a16: + assert hidden_states.size(-1) // 2 == w1.size(2), ( + "Hidden size mismatch") + else: + assert hidden_states.size(-1) == w1.size(2), \ + (f"Hidden size mismatch {hidden_states.size(-1)} " + f"!= {w1.size(2)}") + + assert hidden_states.is_contiguous( + ), "Hidden_states must be contiguous" + assert hidden_states.dim() == 2 + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn + ] + + E, num_tokens, N, K, top_k_num = mk._moe_problem_size( + hidden_states, w1, w2, topk_ids) + + if global_num_experts == -1: + global_num_experts = E + + config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + use_int4_w4a8=self.use_int4_w4a8, + dtype=hidden_states.dtype) + + config = try_get_optimal_moe_config( + w1.size(), + w2.size(), + top_k_num, + config_dtype, + num_tokens, + block_shape=self.block_shape, + ) + + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + elif hidden_states.dtype == torch.float8_e4m3fn: + compute_type = tl.bfloat16 + else: + raise ValueError( + f"Unsupported compute_type: {hidden_states.dtype}") + + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + intermediate_cache1 = _resize_cache(workspace13, + (num_tokens, top_k_num, N)) + intermediate_cache2 = _resize_cache(workspace2, + (num_tokens * top_k_num, N // 2)) + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], + global_num_experts, expert_map)) + + invoke_fused_moe_kernel(hidden_states, + w1, + intermediate_cache1, + a1q_scale, + w1_scale, + w1_zp, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + top_k_num, + config, + compute_type=compute_type, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a8=self.use_int8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + use_int4_w4a8=self.use_int4_w4a8, + per_channel_quant=self.per_act_token_quant, + block_shape=self.block_shape) + + self.activation(activation, intermediate_cache2, + intermediate_cache1.view(-1, N)) + + a2q_scale: Optional[torch.Tensor] = None + + qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( + intermediate_cache2, a2_scale, self.quant_dtype, + self.per_act_token_quant, self.block_shape) + + invoke_fused_moe_kernel(qintermediate_cache2, + w2, + output, + a2q_scale, + w2_scale, + w2_zp, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a8=self.use_int8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + use_int4_w4a8= self.use_int4_w4a8, + per_channel_quant=self.per_act_token_quant, + block_shape=self.block_shape) + + +def modular_triton_fused_moe( + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + use_int4_w4a8: bool, + per_act_token_quant: bool, + block_shape: Optional[List[int]] = None, +) -> mk.FusedMoEModularKernel: + return mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + TritonExperts( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + use_int4_w4a8=use_int4_w4a8, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + ), + ) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py new file mode 100644 index 0000000..6e828a4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -0,0 +1,1643 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import importlib + + +from abc import abstractmethod +from collections.abc import Iterable +from enum import Enum +from typing import Callable, Literal, Optional, overload + +import torch +import torch.nn.functional as F +from torch.nn.parameter import UninitializedParameter + +import vllm.envs as envs +from vllm.config import get_current_vllm_config +from vllm.distributed import (get_dp_group, get_ep_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.distributed.eplb.eplb_state import EplbState +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +# yapf: disable +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEParallelConfig) +# yapf: enable +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEActivationFormat, FusedMoEModularKernel, + FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize) +# from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( +# is_rocm_aiter_moe_enabled) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, grouped_topk, is_power_of_two) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.platforms.interface import CpuArchEnum + +from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx +from vllm import _custom_ops as ops +from lightop import op + +if current_platform.is_cuda_alike(): + from .fused_batched_moe import BatchedTritonExperts + from .fused_moe import TritonExperts, fused_experts + if has_pplx(): + from .pplx_prepare_finalize import (PplxPrepareAndFinalize, + pplx_hidden_dim_scale_bytes) + if has_deep_ep(): + from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize + from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE, + DeepEPLLPrepareAndFinalize) +else: + fused_experts = None # type: ignore + FusedMoEPermuteExpertsUnpermute = None # type: ignore + FusedMoEPrepareAndFinalize = None # type: ignore + +# if is_rocm_aiter_moe_enabled(): + # from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 + # rocm_aiter_grouped_topk as grouped_topk) +if current_platform.is_cpu(): + pass +else: + from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk +if current_platform.is_tpu(): + from .moe_pallas import fused_moe as fused_moe_pallas +else: + fused_moe_pallas = None # type: ignore + +logger = init_logger(__name__) + + +class FusedMoeWeightScaleSupported(Enum): + TENSOR = "tensor" + CHANNEL = "channel" + GROUP = "group" + BLOCK = "block" + + +class FusedMoEMethodBase(QuantizeMethodBase): + + moe: FusedMoEConfig + + @abstractmethod + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + raise NotImplementedError + + def init_prepare_finalize(self, moe: FusedMoEConfig, + quant_config: Optional[QuantizationConfig]): + all2all_manager = get_ep_group().device_communicator.all2all_manager + assert all2all_manager is not None + + self.moe = moe + + prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None + + if moe.use_pplx_kernels: + hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( + moe.max_num_tokens, + moe.hidden_dim, + moe.in_dtype, + moe.quant_dtype, + per_act_token_quant=moe.per_act_token_quant, + block_shape=moe.block_shape, + ) + + all_to_all_args = dict( + max_num_tokens=moe.max_num_tokens, + num_experts=moe.num_experts, + experts_per_token=moe.experts_per_token, # topk + rank=all2all_manager.rank, + world_size=all2all_manager.world_size, + # dp_size actually means tp_size, bug in pplx kernels + dp_size=all2all_manager.tp_group.world_size, + hidden_dim=moe.hidden_dim, + hidden_dim_bytes=hidden_dim_bytes, + hidden_dim_scale_bytes=hidden_scale_bytes, + ) + + num_dispatchers = (all2all_manager.world_size // + all2all_manager.tp_group.world_size) + + # Intranode pplx a2a takes a group name while internode does not. + if not all2all_manager.internode: + all_to_all_args[ + "group_name"] = all2all_manager.cpu_group.group_name + + handle = all2all_manager.get_handle(all_to_all_args) + + prepare_finalize = PplxPrepareAndFinalize( + handle, + max_num_tokens=moe.max_num_tokens, + num_local_experts=moe.num_local_experts, + num_dispatchers=num_dispatchers, + ) + elif moe.use_deepep_ht_kernels: + assert moe.dp_size == all2all_manager.dp_world_size + + all_to_all_args = dict() + handle = all2all_manager.get_handle(all_to_all_args) + prepare_finalize = DeepEPHTPrepareAndFinalize( + handle, + num_dispatchers=all2all_manager.world_size, + dp_size=all2all_manager.dp_world_size, + rank_expert_offset=all2all_manager.rank * + moe.num_local_experts, + ) + + elif moe.use_deepep_ll_kernels: + all_to_all_args = dict( + max_num_tokens_per_dp_rank=moe.max_num_tokens, + token_hidden_size=moe.hidden_dim, + num_ep_ranks=all2all_manager.world_size, + num_global_experts=moe.num_experts, + num_local_experts=moe.num_experts // + all2all_manager.world_size) + handle = all2all_manager.get_handle(all_to_all_args) + + # Note : We may want to use FP8 dispatch even otherwise just to + # reduce datamovement + use_fp8_dispatch = (moe.quant_config is not None + and moe.quant_config.quant_dtype + == current_platform.fp8_dtype() + and moe.quant_config.block_shape + == DEEPEP_QUANT_BLOCK_SHAPE) + + # Note (varun): Whether to use FP8 dispatch or not needs some + # profiling. Turning it off for now. + prepare_finalize = DeepEPLLPrepareAndFinalize( + handle, + max_tokens_per_rank=moe.max_num_tokens, + num_dispatchers=all2all_manager.world_size, + use_fp8_dispatch=use_fp8_dispatch, + ) + + self.topk_indices_dtype = None + if prepare_finalize is not None: + logger.debug("%s", prepare_finalize.__class__.__name__) + self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() + experts = self.select_gemm_impl(prepare_finalize, moe) + self.fused_experts = FusedMoEModularKernel( + prepare_finalize, + experts, + ) + + def select_gemm_impl( + self, + prepare_finalize: FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + ) -> FusedMoEPermuteExpertsUnpermute: + # based on the all2all implementation, select the appropriate + # gemm implementation + raise NotImplementedError( + f"{self.__class__.__name__} must select appropriate gemm " + "implementation based on the prepare_finalize") + + @abstractmethod + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + +@CustomOp.register("unquantized_fused_moe") +class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): + """MoE method without quantization.""" + + def __init__(self, moe: FusedMoEConfig): + super().__init__() + self.fused_experts = fused_experts # type: ignore + self.topk_indices_dtype = None + self.moe = moe + + self.rocm_aiter_moe_enabled = False # is_rocm_aiter_moe_enabled() + if self.rocm_aiter_moe_enabled: + from .rocm_aiter_fused_moe import rocm_aiter_fused_experts + self.rocm_aiter_fused_experts = rocm_aiter_fused_experts + else: + self.rocm_aiter_fused_experts = None # type: ignore + + def select_gemm_impl( + self, + prepare_finalize: FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + ) -> FusedMoEPermuteExpertsUnpermute: + + assert self.fused_experts == fused_experts + + if (prepare_finalize.activation_format == + FusedMoEActivationFormat.BatchedExperts): + logger.debug("BatchedTritonExperts %s", self.moe) + return BatchedTritonExperts( + max_num_tokens=self.moe.max_num_tokens, + num_dispatchers=prepare_finalize.num_dispatchers(), + ) + else: + logger.debug("TritonExperts %s", self.moe) + return TritonExperts() + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, use_nn_moe: bool, + **extra_weight_attrs): + # Fused gate_up_proj (column parallel) + if not use_nn_moe: + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + else: + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + 2 * intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) + if not use_nn_moe: + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + else: + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: + # Pad the weight tensor. This is an optimization on ROCm platform, which + # can benefit from tensors located far enough from one another in memory + if (envs.VLLM_ROCM_MOE_PADDING and current_platform.is_rocm() + and weight.stride(-1) == 1 + and (weight.stride(-2) * weight.element_size()) % 512 == 0): + num_pad = 256 // weight.element_size() + weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] + torch.cuda.empty_cache() + return weight + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + super().process_weights_after_loading(layer) + + # Padding the weight for better performance on ROCm + layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) + layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) + # Lazy import to avoid importing triton. + # from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + # shuffle_weights) + + # if self.rocm_aiter_moe_enabled: + # shuffled_w13, shuffled_w2 = shuffle_weights( + # layer.w13_weight.data, layer.w2_weight.data) + + # layer.w13_weight.data = shuffled_w13 + # layer.w2_weight.data = shuffled_w2 + + if current_platform.is_cpu(): + if current_platform.get_cpu_architecture() == CpuArchEnum.X86: + from vllm.model_executor.layers.fused_moe import cpu_fused_moe + dtype = layer.w13_weight.dtype + if (envs.VLLM_CPU_SGL_KERNEL + and torch._C._cpu._is_amx_tile_supported() + and dtype == torch.bfloat16): + packed_w13_weight = torch.ops._C.convert_weight_packed( + layer.w13_weight) + assert packed_w13_weight.size() == layer.w13_weight.size() + layer.w13_weight.copy_(packed_w13_weight) + del packed_w13_weight + packed_w2_weight = torch.ops._C.convert_weight_packed( + layer.w2_weight) + assert packed_w2_weight.size() == layer.w2_weight.size() + layer.w2_weight.copy_(packed_w2_weight) + layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer) + else: + layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer) + else: + raise NotImplementedError("CPU MOE only supports x86 arch.") + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + use_nn_moe: Optional[bool] = False, + routed_scaling_factor: Optional[float] = None, + use_fused_gate: Optional[bool] = False, + ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `UnquantizedFusedMoEMethod` yet.") + + return self.forward( + x=x, + layer=layer, + router_logits=router_logits, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group, + global_num_experts=global_num_experts, + expert_map=expert_map, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_nn_moe=use_nn_moe, + routed_scaling_factor=routed_scaling_factor, + use_fused_gate=use_fused_gate) + + def forward_cuda( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + use_nn_moe: Optional[bool] = False, + routed_scaling_factor: Optional[float] = None, + use_fused_gate: Optional[bool] = False, + ) -> torch.Tensor: + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + routed_scaling_factor=routed_scaling_factor, + use_fused_gate=use_fused_gate) + + if self.rocm_aiter_moe_enabled: + return self.rocm_aiter_fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + expert_map=expert_map, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input) + else: + return self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + use_nn_moe=use_nn_moe + ) + + def forward_cpu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + use_nn_moe: Optional[bool] = False, + routed_scaling_factor: Optional[float] = None, + use_fused_gate: Optional[bool] = False, + **kwargs, + ): + return layer.cpu_fused_moe( + layer, + x, + use_grouped_topk, + top_k, + router_logits, + renormalize, + topk_group, + num_expert_group, + global_num_experts, + expert_map, + custom_routing_function, + scoring_func, + e_score_correction_bias, + apply_router_weight_on_input, + activation, + ) + + def forward_hpu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + use_nn_moe: Optional[bool] = False, + routed_scaling_factor: Optional[float] = None, + use_fused_gate: Optional[bool] = False, + ) -> torch.Tensor: + assert not use_grouped_topk + assert num_expert_group is None + assert topk_group is None + assert custom_routing_function is None + assert layer is not None + assert apply_router_weight_on_input is False + if scoring_func != "softmax": + raise NotImplementedError( + "Only softmax scoring function is supported for HPU.") + if e_score_correction_bias is not None: + raise NotImplementedError( + "Expert score correction bias is not supported for HPU.") + return layer.hpu_fused_moe(x, layer.w13_weight, layer.w2_weight, + router_logits, top_k) + + def forward_tpu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert not use_grouped_topk + assert num_expert_group is None + assert topk_group is None + assert custom_routing_function is None + assert apply_router_weight_on_input is False + if scoring_func != "softmax": + raise NotImplementedError( + "Only softmax scoring function is supported for TPU.") + if e_score_correction_bias is not None: + raise NotImplementedError( + "Expert score correction bias is not supported for TPU.") + assert activation == "silu", f"{activation} is not supported for TPU." + return fused_moe_pallas(hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk=top_k, + gating_output=router_logits, + global_num_experts=global_num_experts, + expert_map=expert_map, + renormalize=renormalize) + + if current_platform.is_tpu(): + forward_native = forward_tpu + elif current_platform.is_cpu(): + forward_native = forward_cpu + else: + forward_native = forward_cuda + + +def determine_expert_map( + ep_size: int, ep_rank: int, + global_num_experts: int) -> tuple[int, Optional[torch.Tensor]]: + """ + Calculates how many experts should be assigned to each rank for EP and + creates a mapping from global to local expert index. Experts are + distributed evenly across ranks. Any remaining are assigned to the + last rank. + + Args: + ep_size (int): The size of the expert parallel group + global_num_experts (int): The total number of experts in the model. + + Returns: + tuple[int, Optional[torch.Tensor]]: A tuple containing: + - local_num_experts (int): The number of experts assigned + to the current rank. + - expert_map (Optional[torch.Tensor]): A tensor of shape + (global_num_experts,) mapping from global to local index. + Contains -1 for experts not assigned to the current rank. + Returns None if ep_size is 1. + """ + assert ep_size > 0 + if ep_size == 1: + return (global_num_experts, None) + + local_num_experts = global_num_experts // ep_size + + # Create a tensor of size num_experts filled with -1 + expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32) + # Create a expert map for the local experts + if ep_rank < (ep_size - 1): + # Each non-last rank gets local_num_experts experts. + expert_map[ep_rank * local_num_experts: + (ep_rank + 1) * local_num_experts] = \ + torch.arange(0, local_num_experts, dtype=torch.int32) + else: + # All remaining experts are assigned to the last rank. + local_num_experts = (global_num_experts - ep_rank * local_num_experts) + + expert_map[-local_num_experts:] = \ + torch.arange(0, local_num_experts, dtype=torch.int32) + return (local_num_experts, expert_map) + + +class FusedMoE(torch.nn.Module): + """FusedMoE layer for MoE models. + + This layer contains both MergedColumnParallel weights (gate_up_proj / + w13) and RowParallelLinear weights (down_proj/ w2). + + Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We + copy that naming convention here and handle any remapping in the + load_weights function in each model implementation. + + Args: + num_experts: Number of experts in the model + top_k: Number of experts selected for each token + hidden_size: Input hidden state size of the transformer + intermediate_size: Intermediate size of the experts + params_dtype: Data type for the parameters. + reduce_results: Whether to all all_reduce on the output of the layer + renomalize: Whether to renormalize the logits in the fused_moe kernel + quant_config: Quantization configure. + enable_eplb: Whether to enable expert parallelism load balancer. + """ + + def __init__( + self, + num_experts: int, # Global number of experts + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + ep_size: Optional[int] = None, + dp_size: Optional[int] = None, + prefix: str = "", + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + num_redundant_experts: int = 0, + routed_scaling_factor: Optional[float] = None, + ): + super().__init__() + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + tp_size_ = (tp_size if tp_size is not None else + get_tensor_model_parallel_world_size()) + dp_size_ = (dp_size + if dp_size is not None else get_dp_group().world_size) + + vllm_config = get_current_vllm_config() + self.moe_parallel_config: FusedMoEParallelConfig = ( + FusedMoEParallelConfig.make( + tp_size_=tp_size_, + dp_size_=dp_size_, + vllm_parallel_config=vllm_config.parallel_config)) + + self.global_num_experts = num_experts + num_redundant_experts + + # For smuggling this layer into the fused moe custom op + compilation_config = vllm_config.compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError("Duplicate layer name: {}".format(prefix)) + compilation_config.static_forward_context[prefix] = self + self.layer_name = prefix + + self.enable_eplb = enable_eplb + self.expert_load_view: Optional[torch.Tensor] = None + self.logical_to_physical_map: Optional[torch.Tensor] = None + self.logical_replica_count: Optional[torch.Tensor] = None + + # Determine expert maps + if self.use_ep: + if self.enable_eplb: + assert self.global_num_experts % self.ep_size == 0, \ + "EPLB currently only supports even distribution of " \ + "experts across ranks." + else: + assert num_redundant_experts == 0, \ + "Redundant experts are only supported with EPLB." + self.local_num_experts, self.expert_map = determine_expert_map( + ep_size=self.ep_size, + ep_rank=self.ep_rank, + global_num_experts=self.global_num_experts) + else: + self.local_num_experts, self.expert_map = (self.global_num_experts, + None) + + self.top_k = top_k + + assert intermediate_size % self.tp_size == 0 + self.hidden_size = hidden_size + self.intermediate_size_per_partition = intermediate_size // self.tp_size + self.reduce_results = reduce_results + self.renormalize = renormalize + self.use_grouped_topk = use_grouped_topk + if self.use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + self.num_expert_group = num_expert_group + self.topk_group = topk_group + self.custom_routing_function = custom_routing_function + self.scoring_func = scoring_func + self.e_score_correction_bias = e_score_correction_bias + self.apply_router_weight_on_input = apply_router_weight_on_input + self.activation = activation + self.routed_scaling_factor = routed_scaling_factor + + if self.scoring_func != "softmax" and not self.use_grouped_topk: + raise ValueError("Only softmax scoring function is supported for " + "non-grouped topk.") + if current_platform.is_hpu(): + from vllm_hpu_extension.ops import DynamicFusedMOE + self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) + + if vllm_config.model_config is not None: + model_dtype = vllm_config.model_config.dtype + else: + # TODO (bnell): This is a hack to get test_mixtral_moe to work + # since model_config is not set in the pytest test. + model_dtype = params_dtype + + moe = FusedMoEConfig.make( + num_experts=self.global_num_experts, + experts_per_token=top_k, + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + moe_parallel_config=self.moe_parallel_config, + in_dtype=model_dtype, + max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, + quant_config=quant_config, + ) + self.moe_config = moe + self.quant_config = quant_config + + # Note: get_quant_method will look at the layer's local_num_experts + # for heuristic purposes, so it must be initialized first. + quant_method: Optional[QuantizeMethodBase] = None + quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None + else quant_config.get_quant_method(self, prefix)) + + assert quant_method is not None + assert isinstance(quant_method, FusedMoEMethodBase) + self.quant_method = quant_method + + if self.enable_eplb: + from vllm.model_executor.layers.quantization.fp8 import ( + Fp8MoEMethod) + if not isinstance(quant_method, Fp8MoEMethod): + # TODO: Add support for additional quantization methods. + # The implementation for other quantization methods does not + # contain essential differences, but the current quant API + # design causes duplicated work when extending to new + # quantization methods, so I'm leaving it for now. + # If you plan to add support for more quantization methods, + # please refer to the implementation in `Fp8MoEMethod`. + raise NotImplementedError("EPLB is only supported for FP8 " + "quantization for now.") + + if quant_config is None: + # Not considering quant for now, temporarily + self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1 + else: + self.use_nn_moe = False + + moe_quant_params = { + "num_experts": self.local_num_experts, + "hidden_size": hidden_size, + "intermediate_size_per_partition": + self.intermediate_size_per_partition, + "params_dtype": params_dtype, + "weight_loader": self.weight_loader, + "use_nn_moe": self.use_nn_moe, + } + # need full intermediate size pre-sharding for WNA16 act order + if (self.quant_method.__class__.__name__ + in ("GPTQMarlinMoEMethod", + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod")): + moe_quant_params["intermediate_size_full"] = intermediate_size + + if (self.quant_method.__class__.__name__ in ("BlockInt8MoEMethod", + "SlimQuantW4A8Int8MoEMethod", + "SlimQuantW4A8Int8MarlinMoEMethod")): + moe_quant_params["intermediate_size"] = self.intermediate_size_per_partition + + + self.quant_method.create_weights(layer=self, **moe_quant_params) + + from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce + self.tbo_all_reduce = tbo_all_reduce + + # moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion. + self.use_fused_gate = envs.VLLM_ENABLE_MOE_FUSED_GATE \ + and self.e_score_correction_bias is not None \ + and self.global_num_experts // num_expert_group <= 32 \ + and is_power_of_two(e_score_correction_bias.shape[0]) + + # Chunked all2all staging tensor + self.batched_hidden_states: Optional[torch.Tensor] = None + self.batched_router_logits: Optional[torch.Tensor] = None + if (self.moe_parallel_config.use_pplx_kernels + or self.moe_parallel_config.use_deepep_ll_kernels): + self.batched_hidden_states = torch.zeros( + (moe.max_num_tokens, self.hidden_size), + dtype=moe.in_dtype, + device=torch.cuda.current_device()) + + # Note here we use `num_experts` which is logical expert count + self.batched_router_logits = torch.zeros( + (moe.max_num_tokens, num_experts), + dtype=moe.in_dtype, + device=torch.cuda.current_device()) + + @property + def tp_size(self): + return self.moe_parallel_config.tp_size + + @property + def dp_size(self): + return self.moe_parallel_config.dp_size + + @property + def ep_size(self): + return self.moe_parallel_config.ep_size + + @property + def tp_rank(self): + return self.moe_parallel_config.tp_rank + + @property + def dp_rank(self): + return self.moe_parallel_config.dp_rank + + @property + def ep_rank(self): + return self.moe_parallel_config.ep_rank + + @property + def use_ep(self): + return self.moe_parallel_config.use_ep + + @property + def use_pplx_kernels(self): + return self.moe_parallel_config.use_pplx_kernels + + @property + def use_deepep_ht_kernels(self): + return self.moe_parallel_config.use_deepep_ht_kernels + + @property + def use_deepep_ll_kernels(self): + return self.moe_parallel_config.use_deepep_ll_kernels + + def _load_per_tensor_weight_scale(self, shard_id: str, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + expert_id: int): + param_data = param.data + # for per tensor weight quantization + if shard_id in ("w1", "w3"): + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. + idx = 0 if shard_id == "w1" else 1 + param_data[expert_id][idx] = loaded_weight + # If we are in the row parallel case (down_proj) + elif shard_id == "w2": + param_data[expert_id] = loaded_weight + + def _load_model_weight_or_group_weight_scale(self, + shard_dim: int, + expert_data: torch.Tensor, + shard_id: str, + loaded_weight: torch.Tensor, + tp_rank: int, + load_full_w2: bool = False): + """ + Load grouped weight scales for group quantization or model weights + :param shard_dim: dimension to shard + :param expert_data: parameter for a particular expert + :param shard_id: either w1, w2, or w3 + :param loaded_weight: checkpoint weight to load into the param + :param tp_rank: tensor parallel rank + :param load_full_w2: whether or not the w2 loaded should be sharded. + """ + if shard_id == "w2": + # In the case where we have actorder/g_idx, we do not partition the + # w2 scales, as indicated by `load_full` argument, for all tp cases + self._load_w2(shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + load_full=load_full_w2) + elif shard_id in ("w1", "w3"): + self._load_w13(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + + def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, + shard_dim: int, shard_id: str, + loaded_weight: torch.Tensor, + tp_rank: int): + # for per channel weight quantization + if shard_id == "w2": + expert_data.copy_(loaded_weight) + elif shard_id in ("w1", "w3"): + self._load_w13(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + + def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, + shard_id: str, loaded_weight: torch.Tensor, tp_rank: int): + + # Index the loaded weight for tp sharding. + # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim + shard_size = expert_data.shape[shard_dim] // 2 + loaded_weight = loaded_weight.narrow(shard_dim if not self.use_nn_moe else ~shard_dim, + shard_size * tp_rank, + shard_size) + # Narrow parameter and load. + # w1, gate_proj: Load into first logical weight of w13. + if shard_id == "w1": + expert_data = expert_data.narrow(shard_dim, 0, shard_size) + # w3, up_proj: Load into second logical weight of w13. + else: + assert shard_id == "w3" + expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + if not self.use_nn_moe: + expert_data.copy_(loaded_weight) + else: + expert_data.copy_(loaded_weight.T) + + def _load_w2(self, + expert_data: torch.Tensor, + shard_dim: int, + loaded_weight: torch.Tensor, + tp_rank: int, + load_full: bool = False): + + # Index the loaded weight for tp sharding. + # down_proj: "RowParallel" so tp sharding on input_dim + # Narrow parameter and load. + shard_size = expert_data.shape[shard_dim] + if not load_full: + loaded_weight = loaded_weight.narrow(shard_dim if not self.use_nn_moe else ~shard_dim, + shard_size * tp_rank, + shard_size) + # w2, down_proj: Load into only logical weight of w2. + if not self.use_nn_moe: + expert_data.copy_(loaded_weight) + else: + expert_data.copy_(loaded_weight.T) + + def _load_single_value(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, expert_id: int): + param_data = param.data + + # Input scales can be loaded directly and should be equal. + if not self.use_nn_moe: + param_data[expert_id] = loaded_weight + else: + param_data[expert_id] = loaded_weight.T + + def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor, + shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int): + + if shard_id == "w2": + self._load_w2(shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + else: + assert shard_id in ("w1", "w3") + if not self.use_nn_moe: + expert_data.copy_(loaded_weight) + else: + expert_data.copy_(loaded_weight.T) + + def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: + if self.expert_map is None: + return expert_id + return self.expert_map[expert_id].item() + + @overload + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, weight_name: str, + shard_id: str, expert_id: int, + return_success: Literal[False]) -> None: + ... + + @overload + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, weight_name: str, + shard_id: str, expert_id: int, + return_success: Literal[True]) -> bool: + ... + + def weight_loader(self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + return_success: bool = False) -> Optional[bool]: + expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) + if expert_id == -1: + # Failed to load this param since it's not local to this rank + return False if return_success else None + # Hereafter, `expert_id` is local physical id + + quant_method_name = self.quant_method.__class__.__name__ + # compressed-tensors checkpoints with packed weights are stored flipped + # TODO (mgoin): check self.quant_method.quant_config.quant_format + # against known CompressionFormat enum values that have this quality + if self.quant_method.__class__.__name__ in ( + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod"): + loaded_weight = loaded_weight.t().contiguous() + + if shard_id not in ("w1", "w2", "w3"): + raise ValueError(f"shard_id must be ['w1','w2','w3'] but " + f"got {shard_id}.") + + WEIGHT_SCALE_SUPPORTED = [ + e.value for e in FusedMoeWeightScaleSupported + ] + # Fetch the dim to shard the parameter/loaded weight + # based on the shard id. This will be whatever + # dimension intermediate_size_per_partition is used. + SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} + + expert_data = param.data[expert_id] + + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + param.data.copy_(loaded_weight) + return True if return_success else None + + # is_transposed: if the dim to shard the weight + # should be flipped. Required by GPTQ, compressed-tensors + # should be whatever dimension intermediate_size_per_partition is + is_transposed = getattr(param, "is_transposed", False) or self.use_nn_moe + shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] + if is_transposed: + shard_dim = int(not shard_dim) + + full_load = len(loaded_weight.shape) == 3 + if full_load: + shard_dim += 1 + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, UninitializedParameter): + final_shape = list(loaded_weight.shape) + if shard_id in ["w1", "w3"]: + final_shape[1] *= 2 + final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size + param.materialize(final_shape, dtype=loaded_weight.dtype) + + expert_data = param.data if full_load else param.data[expert_id] + + # Case input scale: input_scale loading is only supported for fp8 + if "input_scale" in weight_name: + # this is needed for compressed-tensors only + loaded_weight = loaded_weight.to(param.data.device) + + if ("compressed" in quant_method_name.lower() + and param.data[expert_id] != 1 + and (param.data[expert_id] - loaded_weight).abs() > 1e-5): + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param.data[expert_id]} " + f"vs. {loaded_weight}") + + self._load_single_value(param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + return True if return_success else None + + # Case g_idx + if "g_idx" in weight_name: + self._load_g_idx(shard_dim=0, + shard_id=shard_id, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank) + return True if return_success else None + + # TODO @dsikka: ModelOpt should follow the proper MoE loading pattern + if "ModelOpt" in quant_method_name: + if ('weight_scale_2' in weight_name + or 'input_scale' in weight_name): + self._load_per_tensor_weight_scale(shard_id=shard_id, + param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + elif "weight" in weight_name: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank) + return True if return_success else None + + # Case weight scales, zero_points and offset, weight/input global scales + if ("scale" in weight_name or "zero" in weight_name + or "offset" in weight_name): + # load the weight scales and zp based on the quantization scheme + # supported weight scales/zp can be found in + # FusedMoeWeightScaleSupported + # TODO @dsikka: once hardened, refactor to use vLLM Parameters + # specific to each case + quant_method = getattr(param, "quant_method", None) + if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value: + self._load_per_channel_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank) + elif quant_method in [ + FusedMoeWeightScaleSupported.GROUP.value, + FusedMoeWeightScaleSupported.BLOCK.value, + ]: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank, + load_full_w2=getattr(param, "load_full_w2", False)) + elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value: + self._load_per_tensor_weight_scale(shard_id=shard_id, + param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + else: + raise ValueError( + f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}") + return True if return_success else None + + # Case weight_shape + if "weight_shape" in weight_name: + # only required by compressed-tensors + self._load_single_value(param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + return True if return_success else None + + # Case model weights + if "weight" in weight_name: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank) + return True if return_success else None + + return False if return_success else None + + def get_expert_weights(self) -> Iterable[torch.Tensor]: + weights = list(self.named_parameters()) + assert all(weight.is_contiguous() for _, weight in weights) + + # Filter out the non-expert weights. + # `e_score_correction_bias` is a bias for each logical expert, + # with shape (num_logical_experts,), not an expert weight. + NON_EXPERT_WEIGHTS = { + "e_score_correction_bias", + } + + return [ + weight.view(self.local_num_experts, -1) for name, weight in weights + if name not in NON_EXPERT_WEIGHTS + ] + + def set_eplb_state( + self, + moe_layer_idx: int, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + """ + Register the EPLB state in this layer. + + This is used later in forward pass, where we get the expert mapping + and record the load metrics in `expert_load_view`. + """ + self.expert_load_view = expert_load_view[moe_layer_idx] + self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] + self.logical_replica_count = logical_replica_count[moe_layer_idx] + + @staticmethod + def select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + indices_type: Optional[torch.dtype] = None, + enable_eplb: bool = False, + expert_map: Optional[torch.Tensor] = None, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + routed_scaling_factor: Optional[float] = None, + use_fused_gate: Optional[bool] = False + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Route the input hidden states to the top-k experts based on the + router logits. + + Returns: + (topk_weights, topk_ids) (tuple[torch.Tensor, torch.Tensor]): + The weights and *global physical* expert ids of the top-k experts. + + **Compatibility**: When EPLB is not enabled, the returned ids are + equivalent to global logical ids, so should be compatible with + plain MoE implementations without redundant experts. + """ + from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk + + # DeepSeekv2 uses grouped_top_k + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + if use_fused_gate: + if envs.VLLM_USE_LIGHT_OP: + topk_weights, topk_ids = op.moe_fused_gate( + router_logits, + e_score_correction_bias, + num_expert_group, + topk_group, + top_k, + 0, + routed_scaling_factor, + ) + else: + topk_weights, topk_ids = ops.moe_fused_gate( + router_logits, + e_score_correction_bias, + num_expert_group, + topk_group, + top_k, + routed_scaling_factor=routed_scaling_factor, + n_share_experts_fusion=0, + ) + else: + topk_weights, topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + if indices_type is not None: + topk_ids = topk_ids.to(dtype=indices_type) + elif custom_routing_function is None: + topk_weights, topk_ids, token_expert_indices = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + indices_type=indices_type, + ) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) + if indices_type is not None: + topk_ids = topk_ids.to(dtype=indices_type) + + if enable_eplb: + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + + # 1. Convert the logical expert ids to physical expert ids + # Directly select a random replica for each logical expert + + # TODO: maybe optimize this by using specified kernels, + # or compute pseudo-random indices by modulo + + # In case `indices_type` is not `torch.long` or `torch.int`, + # e.g. `torch.uint32` as required by dispatch/combine kernels + topk_ids_long = topk_ids.long() + replica_indices = ( + torch.rand_like(topk_ids, dtype=torch.float) * + logical_replica_count[topk_ids_long]).long().unsqueeze(-1) + physical_ids = logical_to_physical_map[topk_ids_long].gather( + -1, replica_indices).squeeze(-1) + + topk_ids = physical_ids + + # 2. Record expert load metrics. + + # TODO(bowen): When using `FusedMoEModularKernel`, this + # can be done in a more unified way, since + # `FusedMoEPrepareAndFinalize` will return the expert + # token count, in some cases directly from the kernel. + # However, now there are many code paths not using + # the modular kernel, e.g. calling `fused_experts`, + # so we decide to keep the logic here. + # + # If later refactor moved all the MoE kernel calls + # to the modular kernel, we can move this logic there + # to achieve better efficiency. + + # `expert_load_view`: (num_logical_experts,) + + # Mask out non-local experts + if expert_map is not None: + topk_ids_local = expert_map[topk_ids] + topk_ids_flatten = topk_ids_local.flatten() + else: + topk_ids_flatten = topk_ids.flatten() + + # Should be equivalent to: + # ``` + # topk_ids_masked = topk_ids_local[topk_ids_local >= 0] + # expert_load_view += topk_ids_masked.bincount( + # minlength=expert_load_view.shape[0]) + # ``` + # We use `scatter_add_` since `bincount` cannot be compiled + + # Performance optimization: + # `masked_fill` is significantly faster than `masked_select` + invalid_mask = topk_ids_flatten < 0 + # Replace invalid expert ids with 0 (just a dummy position) + # to avoid out-of-bounds errors in scatter_add_ + index = topk_ids_flatten.masked_fill_(invalid_mask, 0) + # `src` is the valid mask, which is 1 for valid and 0 for invalid + src = ~invalid_mask + + expert_load_view.scatter_add_(dim=0, + index=index.long(), + src=src.to(expert_load_view)) + + topk_ids = topk_ids.to(dtype=indices_type) + + assert topk_ids.dtype == indices_type or indices_type is None + + return topk_weights, topk_ids + + def must_reduce_shared_expert_outputs(self) -> bool: + """ + The shared_experts are typically computed using the RowParallelLinear + layer. The result of this function is typically used as + the reduce_results argument to the module. + When just tensor-parallel is used, it is not required to reduce + the shared_experts results immediately. Instead we reduce at the + once at the end of the MoE op. (Refer to DeepSeekV2MoE module) + With EP and all2all kernels - this is no longer viable as all + GPU ranks in DP, produce the complete set of hidden_states. + Therefore it is required that we reduce the shared_experts output + early. + """ + return (self.use_pplx_kernels or self.use_deepep_ht_kernels + or self.use_deepep_ll_kernels) + + def maybe_all_reduce_tensor_model_parallel( + self, final_hidden_states: torch.Tensor): + """ + The pplx combine kernel reduces across GPU ranks by default. + """ + if (self.use_pplx_kernels or self.use_deepep_ht_kernels + or self.use_deepep_ll_kernels): + return final_hidden_states + else: + return tensor_model_parallel_all_reduce(final_hidden_states) + + def forward(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + # TODO: Once the OOM issue for the TPU backend is resolved, we will + # switch to using the moe_forward custom op. + if current_platform.is_tpu(): + return self.forward_impl(hidden_states, router_logits) + else: + return torch.ops.vllm.moe_forward(hidden_states, router_logits, + self.layer_name) + + def forward_impl_chunked(self, full_hidden_states: torch.Tensor, + full_router_logits: torch.Tensor): + assert self.batched_hidden_states is not None + assert self.batched_router_logits is not None + assert self.batched_hidden_states.dtype == full_hidden_states.dtype + assert self.batched_router_logits.dtype == full_router_logits.dtype + # Check size compatibility. + assert ( + self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)) + assert ( + self.batched_router_logits.size(-1) == full_router_logits.size(-1)) + + full_final_hidden_states = torch.empty_like(full_hidden_states) + + def process_chunk(chunk_start, chunk_end, skip_result_store=False): + chunk_size = chunk_end - chunk_start + hidden_states = full_hidden_states[chunk_start:chunk_end, :] + router_logits = full_router_logits[chunk_start:chunk_end, :] + + assert (self.batched_hidden_states.size(0) # type: ignore + >= chunk_size) + assert (self.batched_router_logits.size(0) # type: ignore + >= chunk_size) + staged_hidden_states = self.batched_hidden_states[: + chunk_size, :] # type: ignore + staged_router_logits = self.batched_router_logits[: + chunk_size, :] # type: ignore + staged_hidden_states.copy_(hidden_states, non_blocking=True) + staged_router_logits.copy_(router_logits, non_blocking=True) + + # Matrix multiply. + final_hidden_states = self.quant_method.apply( + layer=self, + x=staged_hidden_states, + router_logits=staged_router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + activation=self.activation, + enable_eplb=self.enable_eplb, + expert_load_view=self.expert_load_view, + logical_to_physical_map=self.logical_to_physical_map, + logical_replica_count=self.logical_replica_count, + ) + + if not skip_result_store: + full_final_hidden_states[chunk_start:chunk_end, :].copy_( + final_hidden_states, non_blocking=True) + + ctx = get_forward_context() + max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu + moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens + + num_tokens = full_hidden_states.size(0) + for chunk_start_ in range(0, max_tokens_across_dp, + moe_dp_chunk_size_per_rank): + chunk_start = chunk_start_ + chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, + max_tokens_across_dp) + # clamp start and end + chunk_start = min(chunk_start, num_tokens - 1) + chunk_end = min(chunk_end, num_tokens) + + process_chunk(chunk_start, + chunk_end, + skip_result_store=chunk_start_ >= num_tokens) + + return full_final_hidden_states + + def forward_impl(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + assert self.quant_method is not None + if (self.moe_parallel_config.use_pplx_kernels + or self.moe_parallel_config.use_deepep_ll_kernels): + return self.forward_impl_chunked(hidden_states, router_logits) + + do_naive_dispatch_combine: bool = ( + self.dp_size > 1 + and not self.moe_parallel_config.use_deepep_ht_kernels) + if do_naive_dispatch_combine: + hidden_states, router_logits = get_ep_group().dispatch( + hidden_states, router_logits) + + # Matrix multiply. + final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + activation=self.activation, + apply_router_weight_on_input=self.apply_router_weight_on_input, + enable_eplb=self.enable_eplb, + expert_load_view=self.expert_load_view, + logical_to_physical_map=self.logical_to_physical_map, + logical_replica_count=self.logical_replica_count, + use_nn_moe=self.use_nn_moe, + routed_scaling_factor=self.routed_scaling_factor, + use_fused_gate=self.use_fused_gate + ) + + if do_naive_dispatch_combine: + final_hidden_states = get_ep_group().combine(final_hidden_states) + + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + # Default set to False. (May have to add shared expert outputs. + if envs.VLLM_ENABLE_TBO: + final_hidden_states = self.tbo_all_reduce(final_hidden_states) + else: + final_hidden_states = self.maybe_all_reduce_tensor_model_parallel( + final_hidden_states) + + return final_hidden_states + + @classmethod + def make_expert_params_mapping( + cls, + ckpt_gate_proj_name: str, + ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int, + num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]: + + num_physical_experts = num_experts + num_redundant_experts + + # In the returned mapping: + # - `expert_id` is the physical expert id + # - `weight_name` contains the weight name of the logical expert + # So that we should map the expert id to logical in `weight_name` + physical_to_logical_map = \ + EplbState.build_initial_global_physical_to_logical_map( + num_experts, num_redundant_experts) + + return [ + # (param_name, weight_name, expert_id, shard_id) + ("experts.w13_" if weight_name + in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", + f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.", + expert_id, shard_id) for expert_id in range(num_physical_experts) + for shard_id, weight_name in [ + ("w1", ckpt_gate_proj_name), + ("w2", ckpt_down_proj_name), + ("w3", ckpt_up_proj_name), + ] + ] + + + def extra_repr(self) -> str: + + s = ( + f"global_num_experts={self.global_num_experts}, " + f"local_num_experts={self.local_num_experts}, " + f"top_k={self.top_k}, " + f"intermediate_size_per_partition={self.intermediate_size_per_partition}, " # noqa: E501 + f"tp_size={self.tp_size},\n" + f"ep_size={self.ep_size}, " + f"reduce_results={self.reduce_results}, " + f"renormalize={self.renormalize}, " + f"use_grouped_topk={self.use_grouped_topk}") + + if self.use_grouped_topk: + s += f", num_expert_group={self.num_expert_group}, topk_group={self.topk_group}" # noqa: E501 + + s += f", scoring_func='{self.scoring_func}', activation='{self.activation}'" # noqa: E501 + + return s + + +def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, + layer_name: str) -> torch.Tensor: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + assert self.quant_method is not None + + return self.forward_impl(hidden_states, router_logits) + + +def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, + layer_name: str) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +direct_register_custom_op( + op_name="moe_forward", + op_func=moe_forward, + mutates_args=["hidden_states"], + fake_impl=moe_forward_fake, + dispatch_key=current_platform.dispatch_key, + tags=(torch.Tag.needs_fixed_stride_order, ), +) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py new file mode 100644 index 0000000..f332b51 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -0,0 +1,598 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from enum import Enum +from math import prod +from typing import Optional, final + +import torch + +import vllm.envs as envs +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.utils import cdiv + +# +# This file defines a set of base classes used to make MoE kernels more modular. +# The goal is to be able to utilize different communication mechanisms with +# any fused MoE kernel without needing to have combinatoric implementations. +# +# The fused moe kernels are broken down into the following components: +# +# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine] +# +# Each component will be independent of the others except for +# [Quantize-Dispatch] and `[Combine] (see below). The components can then be +# mixed and matched with so that DP+EP can be supported easily for multiple +# MoE kernel implementations. +# +# The following main classes are defined: +# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE +# inputs (e.g. quantization, distribution) and finalization of Moe outputs. +# The prepare method must take care of any needed quantization and the +# finalize method must apply weights and do the final reduction of the output. +# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused +# MoE operation. One important feature to note is that this class does not +# apply topk weights or reduce the final output. +# * FusedMoEModularKernel - an interface class that combines a +# FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to +# provide the standard fused MoE kernel interface. +# +# [Quantize-Prepare] and [Finalize] functionality are bundled into a single +# class `FusedMoEPrepareAndFinalize` since they could use collective +# communication mechanisms that need to be consistent. +# + + +def _moe_problem_size( + a1: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, +) -> tuple[int, int, int, int, int]: + """ + Extract the MoE problem size from the given tensor arguments: + - a: The hidden states, input to the MoE layer. + - w1: The first set of expert weights. + - w2: The second set of expert weights. + - topk_ids: The topk ids. + + Note: extracting the problem shape from the weight and activation tensors is + not obvious. It needs to be done this way specifically due to subtle issues + with particular kernels, e.g. the int4 kernels divide the trailing dimension + by two, so it's not "correct" to extract N or K from the trailing dimension + of w1 or w2. Similarly, some kernels transpose the weights, so this needs + to be kept in mind. + """ + assert w1.dim() == 3 and w2.dim() == 3 + E, N, _ = w1.size() + K = w2.size(1) + + if a1.dim() == 2: + # Make sure we are using the correct a1 (pre-permute). + assert topk_ids.size(0) == a1.size(0), \ + f"{topk_ids.size(0)} != {a1.size(0)}" + M = a1.size(0) + else: + assert a1.dim() == 3 + assert a1.size(0) == E, f"{a1.size(0)} == {E}" + M = a1.size(1) # This is max_num_tokens + + assert topk_ids.dim() == 2 + topk = topk_ids.size(1) + + return E, M, N, K, topk + + +class FusedMoEActivationFormat(Enum): + """ + The standard activation format (num_tokens, hidden dim). + """ + Standard = "standard", + """ + The batched experts format (num experts, max tokens per expert, hidden dim) + """ + BatchedExperts = "batched_experts", + + +# TODO: pass FusedMoEParallelConfig in as ctor parameter? +class FusedMoEPrepareAndFinalize(ABC): + """ + An abstract base class for the [Quantize-Prepare] and [Finalize] steps + described above. + """ + + @abstractmethod + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Perform any quantization (and/or) dispatching needed + for this kernel. + - a1: The (unquantized) input to the MoE layer. + - a1_scale: Optional scales for a1 + - a2_scale: Optional scales for the second MoE gemm. Required to make + sure the quantization is consistent for both gemms. + - topk_ids: The topk ids. + - topk_weights: The topk weights. + - num_experts: The total number of experts in the global expert space. + - expert_map: A tensor mapping expert indices from the global expert + space to the local expert space of the expert parallel shard. + - apply_router_weight_on_input: When True, apply the weights to the + activations, before quantization + dispatching. + + Returns a tuple of: + - quantized + dispatched a. + - quantized + dispatched a1_scales. + - Optional tensor as big as number of local experts that contains the + number of tokens assigned to each local expert. + - Optional dispatched expert topk IDs + - Optional dispatched expert topk weight + """ + raise NotImplementedError + + @abstractmethod + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> None: + """ + Perform any combine plus apply weights and perform a reduction on the + fused experts output. + - output: The output tensor, written in place. Must be (M, K) shape. + - fused_expert_output: The unweighted, unreduced output of the fused + experts, it will have (M, topk, K) shape. + - topk_weights: The weights to be applied to the fused_experts_output. + - topk_ids: The topk_ids. + - apply_router_weight_on_input: When False, apply the weights to + fused_expert_output. + """ + raise NotImplementedError + + @property + @abstractmethod + def activation_format(self) -> FusedMoEActivationFormat: + """ + A property indicating the output format of the activations for the + 'prepare' method. + """ + raise NotImplementedError + + @abstractmethod + def topk_indices_dtype(self) -> Optional[torch.dtype]: + """ + The PrepareFinalize All2All implementations generally constrain the + dtype of the topk_ids they support. This function returns the + required topk indices dtype so it can be respected. + Return None if there are no such restrictions. + """ + raise NotImplementedError + + @abstractmethod + def max_num_tokens_per_rank(self) -> Optional[int]: + """ + Some PrepareFinalize All2All implementations are batched. Meaning, + they can processes only as set of tokens at a time. This + function returns the batch size i.e the maximum number of tokens + the implementation can process at a time. + Return None if there are no such restrictions. + """ + raise NotImplementedError + + @abstractmethod + def num_dispatchers(self) -> int: + raise NotImplementedError + + +class FusedMoEPermuteExpertsUnpermute(ABC): + """ + An abstract base class for the [Permute-Experts-Unpermute] step described + above. + """ + + def __init__( + self, + quant_config: Optional[FusedMoEQuantConfig], + ): + if quant_config is not None: + self.quant_config = quant_config + else: + self.quant_config = FusedMoEQuantConfig() + + @property + @abstractmethod + def activation_formats( + self) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]: + """ + A property which is a tuple of the input and output activation formats + for the 'apply' method. + """ + raise NotImplementedError + + @property + def quant_dtype(self) -> Optional[torch.dtype]: + return self.quant_config.quant_dtype + + @property + def block_shape(self) -> Optional[list[int]]: + return self.quant_config.block_shape + + @property + def per_act_token_quant(self) -> bool: + return self.quant_config.per_act_token_quant + + @property + def per_out_ch_quant(self) -> bool: + return self.quant_config.per_out_ch_quant + + # TODO (bnell): make this return a CHUNK_SIZE or None instead? + @abstractmethod + def supports_chunking(self) -> bool: + """ + A flag indicating whether or not this class supports activation + chunking. + """ + raise NotImplementedError + + @abstractmethod + def supports_expert_map(self) -> bool: + """ + A flag indicating whether or not this class supports expert maps + """ + raise NotImplementedError + + @abstractmethod + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + """ + Compute the shapes for the temporary and final outputs of the two gemms + and activation in the fused expert function. Since the gemms are + independent, the workspace for the first gemm can be shared with the + workspace for the last gemm. + + Returns a tuple of: + - workspace13 shape tuple: must be large enough to hold the + result of either expert gemm. + - workspace2 shape tuple: must be large enough to hold the + result of the activation function. + - output shape tuple: must be exact size of the final gemm output. + - Workspace type: The dtype to use for the workspace tensors. + - Note: in order for activation chunking to work, the first dimension + of each tuple must be the number of tokens. + """ + raise NotImplementedError + + def activation(self, activation: str, output: torch.Tensor, + input: torch.Tensor) -> None: + assert output.size(-1) * 2 == input.size(-1) + if activation == "silu": + torch.ops._C.silu_and_mul(output, input) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(output, input) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + + def enable_chunking(self): + return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \ + self.supports_chunking() + + @abstractmethod + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ): + """ + This function computes the intermediate result of a Mixture of Experts + (MoE) layer using two sets of weights, w1 and w2. + + Parameters: + - output: (torch.Tensor): The unweighted, unreduced output tensor. + - hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE + layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - topk_ids (torch.Tensor): A map of row to expert id. + - activation (str): The activation function to apply after the first + MoE layer. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for + w1. + - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for + w2. + - a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be + used for a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. + - workspace13 (torch.Tensor): A scratch tensor used for gemm outputs + must be large enough to hold output of either MoE gemm. + - workspace2 (torch.Tensor): A scratch tensor used for the activation + function. + - expert_num_tokens: An optional tensor containing the number of tokens + assigned to each expert when using batched experts format input. + """ + raise NotImplementedError + + +def _chunk_scales(scales: Optional[torch.Tensor], start: int, + end: int) -> Optional[torch.Tensor]: + if scales is not None: + if scales.numel() == 1: + return scales + else: + return scales[start:end] + return None + + +@final +class FusedMoEModularKernel(torch.nn.Module): + """ + This class combines a FusedMoEPrepareAndFinalize instance and + a FusedMoEPermuteExpertsUnpermute to provide an interface that + is compatible with the `fused_experts` function in fused_moe.py. + + It takes care of managing any required scratch space. + + Note: Instances of this class should only be used for a single model + layer due to any layer specific state that may be used by the component + objects. + """ + + def __init__( + self, + prepare_finalize: FusedMoEPrepareAndFinalize, + fused_experts: FusedMoEPermuteExpertsUnpermute, + ): + super().__init__() + self.prepare_finalize = prepare_finalize + self.fused_experts = fused_experts + assert prepare_finalize.activation_format == \ + fused_experts.activation_formats[0], ( + f"{prepare_finalize.__class__.__name__}." + f"{prepare_finalize.activation_format} == " + f"{fused_experts.__class__.__name__}." + f"{fused_experts.activation_formats[0]}") + + def forward( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + ) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets + of weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states: (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - topk_weights (torch.Tensor): The topk weights applied at the end of + the layer. + - topk_ids (torch.Tensor): A map of row to expert id. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - activation (str): The activation function to apply after the first + MoE layer. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for + w1. + - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for + w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. + - apply_router_weight_on_input (bool): When true, the topk weights are + applied directly on the inputs. This is only applicable when topk is + 1. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + + a1 = hidden_states + output = a1 if inplace else torch.zeros_like(a1) + + local_num_experts = w1.size(0) + if global_num_experts == -1: + global_num_experts = local_num_experts + + (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids, + _expert_topk_weights) = self.prepare_finalize.prepare( + a1, + a1_scale, + a2_scale, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + self.fused_experts.quant_config, + ) + + # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. + topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids + topk_weights = (topk_weights if _expert_topk_weights is None else + _expert_topk_weights) + + fused_out = None + + if a1q.numel() == 0: + # This happens when none of the tokens from the all2all reach this + # EP rank. Also, note that this is only relevant for CUDAGraph + # incompatible all2all kernels like the DeepEP high-throughput + # kernels. CUDAGraph compatible all2all kernels like the pplx + # kernels and the DeepEP low-latency kernels are always batched + # and can never run into the tensor.numel() == 0 case. + fused_out = torch.empty_like(a1q).to(dtype=a1.dtype) + else: + _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) + + if self.fused_experts.enable_chunking(): + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + num_chunks = cdiv(M, CHUNK_SIZE) + else: + CHUNK_SIZE = M + num_chunks = 1 + + if num_chunks == 1: + (workspace13_shape, workspace2_shape, fused_out_shape, + workspace_dtype) = self.fused_experts.workspace_shapes( + a1, a1q, M, N, K, top_k, global_num_experts, + local_num_experts) + else: + # Use the full M to get the final output shape. + _, _, fused_out_shape, _ = ( + self.fused_experts.workspace_shapes( + a1, a1q, M, N, K, top_k, global_num_experts, + local_num_experts)) + # Use the CHUNK_SIZE to get the workspace shapes. + workspace13_shape, workspace2_shape, _, workspace_dtype = ( + self.fused_experts.workspace_shapes( + a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts, + local_num_experts)) + + # We can reuse the memory between cache1 and cache3 because by the + # time we need cache3, we're done with cache1. + workspace13 = torch.empty(prod(workspace13_shape), + device=a1.device, + dtype=workspace_dtype) + workspace2 = torch.empty(prod(workspace2_shape), + device=a1.device, + dtype=workspace_dtype) + + if num_chunks == 1: + fused_out = _resize_cache(workspace13, fused_out_shape) + + self.fused_experts.apply( + fused_out, + a1q, + w1, + w2, + topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=a1q_scale, + a2_scale=a2_scale, + workspace13=workspace13, + workspace2=workspace2, + expert_num_tokens=expert_num_tokens, + ) + else: + # The leading output dimension may not be equal to M, so + # we compute output indices separately. + M_out = fused_out_shape[0] + assert M_out >= M + factor = M_out // M + assert factor > 0 + OUT_CHUNK_SIZE = CHUNK_SIZE * factor + + fused_out = torch.empty(fused_out_shape, + device=a1q.device, + dtype=workspace_dtype) + + assert cdiv(M_out, OUT_CHUNK_SIZE) == num_chunks, ( + f"{cdiv(M_out, OUT_CHUNK_SIZE)} == {num_chunks}") + + for chunk in range(num_chunks): + begin_chunk_idx = chunk * CHUNK_SIZE + end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M) + begin_out_idx = chunk * OUT_CHUNK_SIZE + end_out_idx = min((chunk + 1) * OUT_CHUNK_SIZE, M_out) + curr_a1q = a1q[begin_chunk_idx:end_chunk_idx] + curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx, + end_chunk_idx) + curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx, + end_chunk_idx) + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + + self.fused_experts.apply( + fused_out[begin_out_idx:end_out_idx], + curr_a1q, + w1, + w2, + curr_topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=curr_a1q_scale, + a2_scale=curr_a2_scale, + workspace13=workspace13, + workspace2=workspace2, + expert_num_tokens=expert_num_tokens, + ) + + self.prepare_finalize.finalize(output, fused_out, topk_weights, + topk_ids, apply_router_weight_on_input) + + return output diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py new file mode 100644 index 0000000..818ec94 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.triton_utils import tl, triton +from vllm.utils import cdiv, round_up + +import vllm.envs as envs +from lightop import op + + +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + + start_idx = pid * tokens_per_thread + + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = tl.load(cumsum_ptr + pid) + end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, + numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +# Triton implementation based on: +# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0 +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts, ) + tokens_cnts = torch.zeros((num_experts + 1, num_experts), + dtype=torch.int32, + device=topk_ids.device) + cumsum = torch.zeros((num_experts + 1, ), + dtype=torch.int32, + device=topk_ids.device) + tokens_per_thread = cdiv(numel, num_experts) + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, + tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1, )]( + num_tokens_post_pad, + tokens_cnts, + cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) + + +def moe_align_block_size( + topk_ids: torch.Tensor, + block_size: int, + num_experts: int, + expert_map: Optional[torch.Tensor] = None, + pad_sorted_ids: bool = False, + num_token: Optional[int] = None +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns the token distribution across experts to be compatible with block + size for matrix multiplication. + + Note: In the case of expert_parallel, moe_align_block_size initially + considers all experts as valid and aligns all tokens appropriately. + Before the function returns it marks the experts_ids that are not in + the current GPU rank as -1 so the MoE matmuls could skip those blocks. + This requires the num_experts input arg to be the num global experts. + + Parameters: + - topk_ids: A tensor of shape [total_tokens, top_k] representing the + top-k expert indices for each token. + - block_size: The block size used in block matrix multiplication. + - num_experts: The total number of experts. + - expert_map: A tensor of shape [num_experts] that maps the expert index + from the global space to the local index space of the current + expert parallel shard. If the expert is not in the current expert + parallel shard, the mapping is set to -1. + - pad_sorted_ids: A flag indicating whether the sorted_token_ids length + should be padded to a multiple of block_size, + + Returns: + - sorted_token_ids: A tensor containing the sorted token indices according + to their allocated expert. + - expert_ids: A tensor indicating the assigned expert index for each block. + - num_tokens_post_padded: The total number of tokens after padding, + ensuring divisibility by block_size. + + This function pads the number of tokens that each expert needs to process + so that it is divisible by block_size. + Padding ensures that during block matrix multiplication, the dimensions + align correctly. + + Example: + Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], + block_size = 4, and num_experts = 4: + - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, + with each expert needing to process 3 tokens. + - As block_size is 4, we pad 1 token for each expert. + - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. + - Then append padding tokens [12, 12, 12, 12] for each block. + - After sorting by expert index, we obtain token_ids + [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. + Tokens 12 are non-existent (padding) and are ignored in + the subsequent matrix multiplication. + - The padding ensures that the total number of tokens is now divisible + by block_size for proper block matrix operations. + """ + if num_token: + if num_token < block_size: + max_num_tokens_padded = min(topk_ids.numel() * block_size, topk_ids.numel() + num_experts * (block_size - 1)) + else: + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + sorted_ids = torch.full((max_num_tokens_padded,), fill_value=topk_ids.numel(), dtype=torch.int32, device=topk_ids.device) + else: + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + if pad_sorted_ids: + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + sorted_ids = torch.empty((max_num_tokens_padded, ), + dtype=torch.int32, + device=topk_ids.device) + sorted_ids.fill_(topk_ids.numel()) + max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + # Expert ids must be zeroed out to prevent index out of bounds error while + # mapping global expert ids to local expert ids in expert parallelism. + if expert_map is not None: + expert_ids = torch.zeros((max_num_m_blocks, ), + dtype=torch.int32, + device=topk_ids.device) + else: + expert_ids = torch.empty((max_num_m_blocks, ), + dtype=torch.int32, + device=topk_ids.device) + num_tokens_post_pad = torch.empty((1), + dtype=torch.int32, + device=topk_ids.device) + + if envs.VLLM_USE_LIGHT_OP: + op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, num_tokens_post_pad, None) + else: + ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, num_tokens_post_pad) + if expert_map is not None: + expert_ids = expert_map[expert_ids] + + return sorted_ids, expert_ids, num_tokens_post_pad \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/moe_pallas.py b/vllm/model_executor/layers/fused_moe/moe_pallas.py new file mode 100644 index 0000000..d35bd00 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/moe_pallas.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch.nn.functional as F +import torch_xla.experimental.custom_kernel # noqa: F401 + + +def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor: + """ + Compute the histogram of a int32 tensor. The bin edges are defined by the + min and max values, with step = 1. + """ + assert input.dtype == torch.int32, "input must be of torch.int32 dtype." + assert min <= max, "min must be less than or equal to max." + + def searchsorted(sorted_sequence: torch.Tensor, + values_to_search: torch.Tensor) -> torch.Tensor: + return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1) + + bin_edges = torch.linspace(min, max, max - min + 1, + dtype=input.dtype).to(input.device) + return searchsorted(bin_edges, input).to(torch.int32) + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + global_num_experts: int, + expert_map: torch.Tensor = None, + renormalize: bool = False, +) -> torch.Tensor: + """ + Args: + hidden_states: [*, hidden_size] + w1: [num_experts, intermediate_size * 2, hidden_size] + w2: [num_experts, hidden_size, intermediate_size] + gating_output: [*, num_experts] + """ + assert expert_map is None, "expert_map is not supported for pallas MoE." + orig_shape = hidden_states.shape + hidden_size = hidden_states.shape[-1] + num_tokens = hidden_states.shape[:-1].numel() + num_experts = w1.shape[0] + intermediate_size = w2.shape[-1] + device = hidden_states.device + dtype = hidden_states.dtype + assert (num_tokens * topk) % 16 == 0, ( + "The Pallas GMM kernel requires num_tokens * topk to be a multiple of " + f"16 but got {num_tokens * topk}") + + hidden_states = hidden_states.view(num_tokens, hidden_size) + gating_output = gating_output.view(num_tokens, num_experts) + topk_weights = gating_output.softmax(dim=-1, dtype=torch.float) + topk_weights, topk_indices = topk_weights.topk(topk, dim=-1) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights = topk_weights.to(dtype) + + topk_indices = topk_indices.flatten() + topk_argsort_indices = topk_indices.argsort() + topk_argsort_revert_indices = topk_argsort_indices.argsort() + token_indices = torch.arange(num_tokens, + device=device).repeat_interleave(topk) + token_indices = token_indices[topk_argsort_indices] + group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1) + + x = hidden_states[token_indices] + x = torch.ops.xla.gmm(x, w1, group_sizes, transpose_rhs=True) + x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:] + x = torch.ops.xla.gmm(x, w2, group_sizes, transpose_rhs=True) + x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size) + + x = x * topk_weights.unsqueeze(dim=-1) + x = x.sum(dim=-2) + x = x.reshape(orig_shape) + return x diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py new file mode 100644 index 0000000..20ee0d9 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size) +from vllm.model_executor.layers.fused_moe.utils import _fp8_perm + + +def _moe_permute( + curr_hidden_states: torch.Tensor, + a1q_scale: Optional[torch.Tensor], + curr_topk_ids: torch.Tensor, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + block_m: int, +) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, + torch.Tensor]: + """ + Determine the sorted_token_ids, expert_ids for the given problem size. + Permute the hidden states and scales according to `sorted_token_ids`. + """ + top_k_num = curr_topk_ids.size(1) + + tokens_in_chunk = curr_hidden_states.size(0) + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(curr_topk_ids, + block_m, + global_num_experts, + expert_map, + pad_sorted_ids=True)) + + inv_perm: Optional[torch.Tensor] = None + + num_tokens = top_k_num * tokens_in_chunk + expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0) + inv_perm = torch.argsort(sorted_token_ids)[:num_tokens] + + # Permute according to sorted token ids. + sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) + + curr_hidden_states = _fp8_perm(curr_hidden_states, + sorted_token_ids // top_k_num) + + if a1q_scale is not None: + a1q_scale = a1q_scale[sorted_token_ids // top_k_num] + + return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, + inv_perm) + + +def _moe_unpermute_and_reduce( + out: torch.Tensor, + curr_hidden: torch.Tensor, + inv_perm: Optional[torch.Tensor], + topk_weight: torch.Tensor, + apply_router_weight_on_input: bool, +) -> None: + """ + Unpermute the final result and apply topk_weights, then perform the final + reduction on the hidden states. + """ + M, topk = topk_weight.size() + K = curr_hidden.size(-1) + if inv_perm is not None: + curr_hidden = curr_hidden[inv_perm, ...] + curr_hidden = curr_hidden.view(-1, topk, K) + if not apply_router_weight_on_input: + curr_hidden.mul_(topk_weight.view(M, -1, 1)) + ops.moe_sum(curr_hidden, out) + + +def moe_permute( + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + token_expert_indices: torch.Tensor, + topk: int, + n_expert: int, + n_local_expert: int, + expert_map: Optional[torch.Tensor] = None, + align_block_size: Optional[int] = None, + fill_invalid_expert: int = -1 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + This function expands and permutes activation to gather uncontinuous tokens + for each expert. + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - topk_weights (torch.Tensor): topk expert route weight for each token. + - topk_ids (torch.Tensor): topk expert route id for each token. + - token_expert_indices (torch.Tensor): indice for expanded hidden. + - topk (int): The number of top-k experts to select. + - n_expert (int): The number of expert. + - n_local_expert (int): The number of expert in current EP rank. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - align_block_size (Optional[int]): align group gemm block size for deepgemm + - fill_invalid_expert(int): fill expert id in m_indices for invalid expert + to workaround DeepGemm unsupported -1 in m_indices + Returns: + - permuted_hidden_states (torch.Tensor): permuted activation. + - expert_first_token_offset (torch.Tensor): offset of the first token + of each expert for standard grouped gemm. if enable 'align_block_size' + expert_first_token_offset will align up to 'align_block_size'. + - src_row_id2dst_row_id_map (torch.Tensor): idx map for moe_unpermute. + - m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records + the group which the j-th row of the LHS belong to.` + """ + n_token, n_hidden = hidden_states.size() + assert (n_hidden * hidden_states.element_size() + ) % 16 == 0, "permue kernel need hidden dim align to 16B" + permuted_row_size = n_token * topk + if align_block_size is not None: + permuted_row_size = (permuted_row_size + n_expert * + (align_block_size - 1) + align_block_size - + 1) // align_block_size * align_block_size + + permuted_hidden_states = torch.empty( + (permuted_row_size, n_hidden), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + m_indices = torch.full((permuted_row_size, ), + fill_invalid_expert, + dtype=torch.int32, + device=hidden_states.device) + expert_first_token_offset = torch.empty(n_local_expert + 1, + dtype=torch.int64, + device=hidden_states.device) + src_row_id2dst_row_id_map = torch.empty((n_token, topk), + dtype=torch.int32, + device=hidden_states.device) + torch.ops._moe_C.moe_permute(hidden_states, topk_weights, topk_ids, + token_expert_indices, expert_map, n_expert, + n_local_expert, topk, align_block_size, + permuted_hidden_states, + expert_first_token_offset, + src_row_id2dst_row_id_map, m_indices) + return (permuted_hidden_states, expert_first_token_offset, + src_row_id2dst_row_id_map, m_indices) + + +def moe_unpermute( + permuted_hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + src_row_id2dst_row_id_map: torch.Tensor, + expert_first_token_offset: torch.Tensor, + topk: int, + n_expert: int, + n_local_expert: int, +) -> torch.Tensor: + """ + This function expands and permutes activation to gathering uncontinuous + tokens for each expert. + Parameters: + - permuted_hidden_states (torch.Tensor): permuted activation. + - topk_weights (torch.Tensor): topk expert route weight for each token. + - topk_ids (torch.Tensor): topk expert route id for each token. + - expert_first_token_offset (torch.Tensor): offset of the first token + of each expert for grouped gemm. + - topk (int): The number of top-k experts to select. + - n_expert (int): The number of expert. + - n_local_expert (int): The number of expert in current EP rank. + Returns: + - hidden_states (torch.Tensor): The reduced and unpermuted activation + tensor. + """ + n_token, n_hidden = topk_weights.size(0), permuted_hidden_states.size(-1) + assert (n_hidden * permuted_hidden_states.element_size() + ) % 16 == 0, "unpermue kernel need hidden dim align to 16B" + hidden_states = torch.empty((n_token, n_hidden), + dtype=permuted_hidden_states.dtype, + device=permuted_hidden_states.device) + + torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights, + topk_ids, src_row_id2dst_row_id_map, + expert_first_token_offset, n_expert, + n_local_expert, topk, hidden_states) + return hidden_states + + +def moe_permute_unpermute_supported(): + return torch.ops._moe_C.moe_permute_unpermute_supported() diff --git a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py new file mode 100644 index 0000000..6160da7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch.nn.functional as F + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + global_num_experts: int, + expert_map: torch.Tensor = None, + renormalize: bool = False, +) -> torch.Tensor: + """ + Args: + hidden_states: [*, hidden_size] + w1: [num_experts, intermediate_size * 2, hidden_size] + w2: [num_experts, hidden_size, intermediate_size] + gating_output: [*, num_experts] + expert_map: [num_experts] + """ + orig_shape = hidden_states.shape + hidden_size = hidden_states.shape[-1] + num_tokens = hidden_states.shape[:-1].numel() + num_experts = w1.shape[0] + intermediate_size = w2.shape[-1] + dtype = hidden_states.dtype + + hidden_states = hidden_states.view(num_tokens, hidden_size) + gating_output = gating_output.view(num_tokens, global_num_experts) + topk_weights = gating_output.softmax(dim=-1, dtype=torch.float) + topk_weights, selected_experts = topk_weights.topk(topk, dim=-1) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights = topk_weights.to(dtype) + + if expert_map is not None: + selected_experts = expert_map[selected_experts] + + final_hidden_states = None + for expert_idx in range(num_experts): + expert_w1 = w1[expert_idx] + expert_w2 = w2[expert_idx] + expert_mask = (selected_experts == expert_idx) + expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True) + x = F.linear(hidden_states, expert_w1) + gate = F.silu(x[:, :intermediate_size]) + x = x[:, intermediate_size:] * gate + x = F.linear(x, expert_w2) + current_hidden_states = x * expert_weights + if final_hidden_states is None: + final_hidden_states = current_hidden_states + else: + final_hidden_states = final_hidden_states + current_hidden_states + + return final_hidden_states.view(orig_shape) # type: ignore diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py new file mode 100644 index 0000000..112305a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -0,0 +1,233 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import pplx_kernels as pplx +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.utils import ( + _validate_scale_shape, moe_kernel_quantize_input) +from vllm.utils import cdiv, round_up + + +def pplx_hidden_dim_scale_bytes( + max_num_tokens: int, + hidden_dim: int, + in_dtype: torch.dtype, + quant_dtype: Optional[torch.dtype], + per_act_token_quant: bool, + block_shape: Optional[list[int]], +): + # All pplx byte sizes must be 16-byte aligned. + align = 16 + + # For blocked per token: set to + # ceil_div(hidden_dim, block_size) * sizeof(float32) + # For per-token: set to 4 * sizeof(float32) (x4 for alignment) + if quant_dtype is not None: + assert quant_dtype.itemsize == 1 + hidden_dim_bytes = hidden_dim * quant_dtype.itemsize + elem_size = torch.float32.itemsize + + if per_act_token_quant: + # per-token (M x 1) + assert block_shape is None + hidden_scale_bytes = elem_size + elif block_shape is not None: + # per-group (M x K_tiles) + block_size = block_shape[1] + num_blocks = cdiv(hidden_dim, block_size) + hidden_scale_bytes = num_blocks * elem_size + else: + # per-tensor (1 x 1) + hidden_scale_bytes = elem_size + else: + hidden_dim_bytes = hidden_dim * in_dtype.itemsize + hidden_scale_bytes = 0 + + return ( + round_up(hidden_dim_bytes, align), + round_up(hidden_scale_bytes, align), + ) + + +class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): + + def __init__( + self, + a2a: pplx.AllToAll, + max_num_tokens: int, + num_local_experts: int, + num_dispatchers: int, + ): + super().__init__() + assert max_num_tokens > 0 + assert num_local_experts > 0 + self.a2a = a2a + self.max_num_tokens = max_num_tokens + self.num_local_experts = num_local_experts + self.num_dispatchers_ = num_dispatchers + + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts + + def max_num_tokens_per_rank(self) -> Optional[int]: + return self.max_num_tokens + + def topk_indices_dtype(self) -> Optional[torch.dtype]: + return torch.uint32 + + def num_dispatchers(self) -> int: + return self.num_dispatchers_ + + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor], Optional[torch.Tensor]]: + num_tokens = a1.size(0) # M + hidden_dim = a1.size(-1) # K + + assert topk_ids.size(0) == num_tokens + # assert expert_map is None, "NYI" + + # Is this always going to be a1.device? + device = a1.device + + if apply_router_weight_on_input: + topk = topk_ids.size(1) + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1") + a1 = a1 * topk_weights.to(a1.dtype) + + repeat_cols = 4 + repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0) + a1q, a1q_scale = moe_kernel_quantize_input( + a1, (None if quant_config.per_act_token_quant else a1_scale), + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=quant_config.per_act_token_quant, + block_shape=quant_config.block_shape) + + _validate_scale_shape(a1q, a1q_scale, quant_config.per_act_token_quant, + quant_config.block_shape) + + if a1q_scale is not None: + scalar_scales = a1q_scale.numel() == 1 + + # pplx requires 2-d scales even for scalar scales + if a1q_scale.dim() <= 1: + assert scalar_scales + a1q_scale = a1q_scale.view(1, 1) + + orig_a_scale_block_shape = a1q_scale.shape[-1] + + if not quant_config.is_block_quantized: + # TODO (bnell): use group_broadcast instead? + a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols) + + assert a1q_scale is None or a1q_scale.ndim == 2, \ + f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}" + + expert_num_tokens = torch.empty( + self.num_local_experts, + dtype=torch.int32, + device=device, + ) + + expert_x = torch.empty( + (self.num_local_experts, + self.max_num_tokens * self.num_dispatchers(), hidden_dim), + dtype=a1q.dtype, + device=device, + ) + + expert_x_scale: Optional[torch.Tensor] = None + if a1q.dtype.itemsize == 1: + if quant_config.is_per_act_token: + # (M x 1) -> (E x M x K) + final_dim = expert_x.size(2) + elif quant_config.is_per_tensor: + # (1 x 1) -> (E x 1 x 1) + final_dim = 1 + else: + # (M x K_tiles) -> (E x M x K_tiles) + assert quant_config.block_shape is not None + num_blocks = cdiv(expert_x.size(2), + quant_config.block_shape[1]) + final_dim = num_blocks + + expert_x_scale_shape = ( + self.num_local_experts, + expert_x.size(1), + round_up(final_dim, 4) # round up for alignment + ) + + expert_x_scale = torch.empty( + expert_x_scale_shape, + dtype=torch.float32, + device=expert_x.device, + ) + + # This argument is optional, defaults to indices.size(0) + # There's not much point setting this unless it is != indices.size(0) + bound_m: Optional[torch.Tensor] = None + + self.a2a.dispatch( + out_expert_num_tokens=expert_num_tokens, + out_expert_x=expert_x, + out_expert_x_scale=expert_x_scale, + dp_x=a1q, + dp_x_scale=a1q_scale, + indices=topk_ids, + bound_m=bound_m, + ) + + if expert_x_scale is not None: + expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] + assert expert_x_scale.ndim == 3 + + return expert_x, expert_x_scale, expert_num_tokens, None, None + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> None: + # This argument is optional + # There's not much point setting this unless it is != topk_ids.size(0) + bound_m: Optional[torch.Tensor] = None + + # TODO (bnell): fails in test_pplx_moe.py, figure out what's going on + #num_tokens = output.size(0) # M + #assert topk_ids.size(0) == num_tokens, ( + # f"{topk_ids.size(0)} == {num_tokens}") + assert topk_ids.size() == topk_weights.size(), ( + f"{topk_ids.size()} == {topk_weights.size()}") + assert output.size(0) <= self.max_num_tokens, ( + f"{output.size(0)} <= {self.max_num_tokens}") + assert output.size(1) == fused_expert_output.size(-1) + + # Set weights to 1 if we did them in dispatch. This is hacky. + if apply_router_weight_on_input: + topk_weights = torch.ones_like(topk_weights) + + self.a2a.combine(out_tokens=output, + indices=topk_ids, + weights=topk_weights, + expert_y=fused_expert_output, + bound_m=bound_m) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py new file mode 100644 index 0000000..e1114ef --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + _moe_unpermute_and_reduce) +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) + + +class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): + + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> Optional[int]: + return None + + def topk_indices_dtype(self) -> Optional[torch.dtype]: + return None + + def num_dispatchers(self) -> int: + return 1 + + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor], Optional[torch.Tensor]]: + + if apply_router_weight_on_input: + topk = topk_ids.size(1) + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, \ + "apply_router_weight_on_input is only implemented for topk=1" + a1.mul_(topk_weights.to(a1.dtype)) + + a1q, a1q_scale = moe_kernel_quantize_input( + a1, a1_scale, quant_config.quant_dtype, + quant_config.per_act_token_quant, quant_config.block_shape) + + return a1q, a1q_scale, None, None, None + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> None: + _moe_unpermute_and_reduce(output, fused_expert_output, None, + topk_weights, apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py new file mode 100644 index 0000000..931fd6c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -0,0 +1,430 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from enum import IntEnum +from functools import cache +from typing import Optional + +import torch + +from vllm import envs +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op + + +class QuantMethod(IntEnum): + # This allows interfacing with AITER QuantType Enum + # without importing the QuantType from AITER globally. + + # Note that these quantization methods are + # supported in AITER package. However, + # not all are used in this module. + + NO = 0 # a16w16 + PER_TENSOR = 1 # w8a8 (pre_Tensor) + PER_TOKEN = 2 # w8a8/w8a4 (per_Token) + BLOCK_1X32 = 3 # fp4x2 + BLOCK_1X128 = 4 # block quantized w8a8 (per_1x128) + BLOCK_128x128 = 5 # block quantized w8a8 (per_128x128) + + +class ActivationMethod(IntEnum): + # This allows interfacing with AITER ActivationType enum + # without importing the ActivationType enum from AITER globally. + SILU = 0 + GELU = 1 + + +@cache +def is_rocm_aiter_moe_enabled() -> bool: + return False + # return current_platform.is_rocm() \ + # and envs.VLLM_ROCM_USE_AITER_MOE \ + # and envs.VLLM_ROCM_USE_AITER + + +def rocm_aiter_asm_moe_tkw1_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: Optional[torch.Tensor] = None, + fc2_scale: Optional[torch.Tensor] = None, + fc1_smooth_scale: Optional[torch.Tensor] = None, + fc2_smooth_scale: Optional[torch.Tensor] = None, + a16: bool = False, + per_tensor_quant_scale: Optional[torch.Tensor] = None, + expert_mask: Optional[torch.Tensor] = None, + activation_method: int = ActivationMethod.SILU.value) -> torch.Tensor: + + from aiter import ActivationType + from aiter.fused_moe_bf16_asm import asm_moe_tkw1 + + activation = ActivationType(activation_method) + + return asm_moe_tkw1(hidden_states, + w1, + w2, + topk_weights, + topk_ids, + fc1_scale=fc1_scale, + fc2_scale=fc2_scale, + fc1_smooth_scale=fc1_smooth_scale, + fc2_smooth_scale=fc2_smooth_scale, + a16=a16, + per_tensor_quant_scale=per_tensor_quant_scale, + expert_mask=expert_mask, + activation=activation) + + +def rocm_aiter_asm_moe_tkw1_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: Optional[torch.Tensor] = None, + fc2_scale: Optional[torch.Tensor] = None, + fc1_smooth_scale: Optional[torch.Tensor] = None, + fc2_smooth_scale: Optional[torch.Tensor] = None, + a16: bool = False, + per_tensor_quant_scale: Optional[torch.Tensor] = None, + expert_mask: Optional[torch.Tensor] = None, + activation_method: int = ActivationMethod.SILU.value) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +def rocm_aiter_topk_softmax_impl(topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool) -> None: + from aiter import topk_softmax + topk_softmax(topk_weights, topk_indices, token_expert_indices, + gating_output, renormalize) + + +def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool) -> None: + pass + + +def rocm_aiter_biased_grouped_topk_impl( + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + routed_scaling_factor: float = 1.0 # mul to topk_weights +) -> None: + + from aiter import biased_grouped_topk + + biased_grouped_topk(gating_output, correction_bias, topk_weights, topk_ids, + num_expert_group, topk_group, need_renorm, + routed_scaling_factor) + + +def rocm_aiter_biased_grouped_topk_fake( + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + routed_scaling_factor: float = 1.0 # mul to topk_weights +) -> None: + pass + + +def rocm_aiter_grouped_topk_impl( + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0 # mul to topk_weights +) -> None: + + from aiter import grouped_topk + + grouped_topk(gating_output, topk_weights, topk_ids, num_expert_group, + topk_group, need_renorm, scoring_func, routed_scaling_factor) + + +def rocm_aiter_grouped_topk_fake( + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0 # mul to topk_weights +) -> None: + pass + + +def rocm_aiter_fused_moe_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + expert_mask: Optional[torch.Tensor] = None, + activation_method: int = ActivationMethod.SILU.value, + quant_method: int = QuantMethod.NO.value, + doweight_stage1: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + from aiter import ActivationType, QuantType + from aiter.fused_moe import fused_moe + + activation = ActivationType(activation_method) + quant_type = QuantType(quant_method) + + return fused_moe(hidden_states, w1, w2, topk_weight, topk_ids, expert_mask, + activation, quant_type, doweight_stage1, w1_scale, + w2_scale, a1_scale, a2_scale) + + +def rocm_aiter_fused_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + expert_mask: Optional[torch.Tensor] = None, + activation_method: int = ActivationMethod.SILU.value, + quant_method: int = QuantMethod.NO.value, + doweight_stage1: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +if current_platform.is_rocm(): + + direct_register_custom_op( + op_name="rocm_aiter_asm_moe_tkw1", + op_func=rocm_aiter_asm_moe_tkw1_impl, + mutates_args=[], + fake_impl=rocm_aiter_asm_moe_tkw1_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_fused_moe", + op_func=rocm_aiter_fused_moe_impl, + mutates_args=[], + fake_impl=rocm_aiter_fused_moe_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_topk_softmax", + op_func=rocm_aiter_topk_softmax_impl, + mutates_args=["topk_weights", "topk_indices", "token_expert_indices"], + fake_impl=rocm_aiter_topk_softmax_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_biased_grouped_topk", + op_func=rocm_aiter_biased_grouped_topk_impl, + mutates_args=["topk_weights", "topk_ids"], + fake_impl=rocm_aiter_biased_grouped_topk_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_grouped_topk", + op_func=rocm_aiter_grouped_topk_impl, + mutates_args=["topk_weights", "topk_ids"], + fake_impl=rocm_aiter_grouped_topk_fake, + dispatch_key=current_platform.dispatch_key, + ) + + +def rocm_aiter_grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: + token = hidden_states.shape[0] + device = hidden_states.device + topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) + topk_weights = torch.empty((token, topk), + dtype=torch.float32, + device=device) + + if e_score_correction_bias is not None: + torch.ops.vllm.rocm_aiter_biased_grouped_topk( + gating_output, + e_score_correction_bias, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + renormalize, + ) + else: + assert (scoring_func == "softmax" or scoring_func == "sigmoid") + torch.ops.vllm.rocm_aiter_grouped_topk( + gating_output, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + renormalize, + scoring_func, + ) + + return topk_weights, topk_ids + + +def rocm_aiter_fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + per_channel_quant: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[list[int]] = None, + expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: + + activation_method = (ActivationMethod.SILU + if activation == "silu" else ActivationMethod.GELU) + # All AITER Fused MoE kernels are expecting the following datatypes + topk_weights = topk_weights.to(torch.float32) + topk_ids = topk_ids.to(torch.int32) + + if expert_map is not None: + expert_mask = (expert_map > -1).to(torch.int32) + else: + expert_mask = None + + # w8a8 per-channel quantization + if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8: + # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input` + # This applies topk_weights on the GEMM output of the first FC layer + # rather than the second FC. + assert (topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" + assert topk_weights.shape[-1] == 1, ( + "Only support topk=1 when" + " `apply_router_weight_on_input` is True") + + return torch.ops.vllm.rocm_aiter_asm_moe_tkw1( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + fc1_scale=w1_scale, + fc2_scale=w2_scale, + fc1_smooth_scale=None, + fc2_smooth_scale=None, + a16=False, + per_tensor_quant_scale=None, + expert_mask=expert_mask, + activation_method=activation_method) + + else: + quant_method = QuantMethod.NO.value + + # w8a8 block-scaled + if block_shape is not None and use_fp8_w8a8: + assert not apply_router_weight_on_input, ( + "apply_router_weight_on_input is\ + not supported for block scaled moe") + assert w1_scale is not None + assert w2_scale is not None + quant_method = QuantMethod.BLOCK_128x128.value + elif use_fp8_w8a8: + # Currently only per tensor quantization method is enabled. + quant_method = QuantMethod.PER_TENSOR.value + + if apply_router_weight_on_input: + assert (topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" + _, topk = topk_weights.shape + assert ( + topk == 1 + ), "Only support topk=1 when `apply_router_weight_on_input` is True" + + return torch.ops.vllm.rocm_aiter_fused_moe( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + expert_mask=expert_mask, + quant_method=quant_method, + activation_method=activation_method, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + doweight_stage1=apply_router_weight_on_input) + + +def rocm_aiter_topk_softmax(topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool) -> tuple[torch.Tensor, ...]: + torch.ops.vllm.rocm_aiter_topk_softmax(topk_weights, topk_indices, + token_expert_indices, gating_output, + renormalize) + return topk_weights, topk_indices + + +def shuffle_weights( + *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16) +) -> tuple[torch.Tensor, ...]: + """ + Applies shuffle_weight function from AITER to each + input tensor and returns them. + + Rearranges (shuffles) the input tensor/s + into a specified block layout for optimized computation. + + Args: + *tensors: Variable number of torch.Tensor objects. + layout: A pair of integers specifying the + block sizes used to divide the tensors during shuffling. + Default is (16, 16). + + Returns: + A Tuple of shuffled tensors. + """ + from aiter.ops.shuffle import shuffle_weight + + return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py new file mode 100644 index 0000000..e660376 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape) +from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts + + +class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_act_token_quant: bool = False, + block_shape: Optional[list[int]] = None, + allow_deep_gemm: bool = False, + ): + super().__init__( + FusedMoEQuantConfig.make( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + )) + self.triton_expert = TritonExperts( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + ) + self.allow_deep_gemm = (allow_deep_gemm and not per_act_token_quant + and use_fp8_w8a8) + self.deep_gemm_expert = DeepGemmExperts( + ) if self.allow_deep_gemm else None + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + assert (self.deep_gemm_expert is None + or self.triton_expert.activation_formats + == self.deep_gemm_expert.activation_formats) + return self.triton_expert.activation_formats + + def supports_chunking(self) -> bool: + dge = self.deep_gemm_expert + te = self.triton_expert + return ((dge is None or dge.supports_chunking()) + and (te is None or te.supports_chunking())) + + def supports_expert_map(self) -> bool: + dge = self.deep_gemm_expert + te = self.triton_expert + return ((dge is None or dge.supports_expert_map()) + and (te is None or te.supports_expert_map())) + + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + # Note: the deep gemm workspaces are strictly larger than the triton + # workspaces so we can be pessimistic here and allocate for DeepGemm + # even if we fall back to triton later, e.g. if expert maps are set. + if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): + assert self.deep_gemm_expert is not None + return self.deep_gemm_expert.workspace_shapes( + a, aq, M, N, K, topk, global_num_experts, local_num_experts) + else: + return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk, + global_num_experts, + local_num_experts) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ): + use_deep_gemm = (self.allow_deep_gemm + and _valid_deep_gemm(hidden_states, w1, w2)) + + experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert + assert experts is not None + + experts.apply( + output, + hidden_states, + w1, + w2, + topk_ids, + activation, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1q_scale, + a2_scale, + workspace13, + workspace2, + expert_num_tokens, + ) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py new file mode 100644 index 0000000..224830b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from math import prod +from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) +try: + from lmslim.layers.gemm.int8_utils import ( + per_token_group_quant_int8, per_token_quant_int8) +except Exception: + print("INFO: Please install lmslim if you want to use int utils.\n") +from vllm.utils import cdiv + + +def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor: + """ + Shrink the given tensor and apply the given view to it. This is + used to resize the intermediate fused_moe caches. + """ + assert prod(v) <= x.numel( + ), f"{v} ({prod(v)}) <= {x.shape} ({x.numel()})" # CUDAGRAPH unfriendly? + return x.flatten()[:prod(v)].view(*v) + + +def _fp8_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + per_act_token: bool, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Perform fp8 quantization on the inputs. If a block_shape + is provided, the output will be blocked. + """ + if block_shape is None: + A, A_scale = ops.scaled_fp8_quant( + A, A_scale, use_per_token_if_dynamic=per_act_token) + else: + assert not per_act_token + assert len(block_shape) == 2 + _, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert cdiv(A.size(-1), block_k) == A_scale.size(-1) + + return A, A_scale + + +def _int8_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + per_act_token: bool, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Perform int8 quantization on the inputs. If a block_shape + is provided, the output will be blocked. + """ + + # If weights are per-channel (per_channel_quant=True), then + # activations apply per-token quantization. Otherwise, assume + # activation tensor-wise fp8/int8 quantization, dynamic or static + if block_shape is None: + assert per_act_token, \ + "int8 quantization only supports block or channel-wise" + A, A_scale = per_token_quant_int8(A) + else: + assert not per_act_token + assert len(block_shape) == 2 + _, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_int8(A, block_k) + assert cdiv(A.size(-1), block_k) == A_scale.size(-1) + + return A, A_scale + + +def moe_kernel_quantize_input( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + quant_dtype: Optional[torch.dtype], + per_act_token_quant: bool, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + if quant_dtype == torch.float8_e4m3fn: + return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) + elif quant_dtype == torch.int8: + return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) + else: + return A, A_scale + + +def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: + """ + A permutation routine that works on fp8 types. + """ + if torch.is_floating_point(m) and m.dtype.itemsize == 1: + return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) + else: + return m[idx, ...] + + +def normalize_scales_shape( + scales: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + if scales is not None: + if scales.numel() == 1: + scales = scales.view(1, 1) + else: + scales = scales.view(-1, scales.size(-1)) + return scales + + +def normalize_batched_scales_shape( + scales: Optional[torch.Tensor], + num_experts: int, +) -> Optional[torch.Tensor]: + if scales is not None and scales.ndim < 3: + if scales.numel() == 1: + scales = scales.view(1) + scales = torch.repeat_interleave(scales, num_experts, + dim=0).view(num_experts, 1, 1) + else: + scales = scales.view(num_experts, -1, scales.size(-1)) + + return scales + + +def _validate_scale_shape( + a: torch.Tensor, + a_scale: Optional[torch.Tensor], + per_act_token_quant: bool, + block_shape: Optional[list[int]], +) -> None: + if a_scale is None: + return + + if not per_act_token_quant and block_shape is None: + assert a_scale.numel() == 1, f"{a_scale.shape}" + elif per_act_token_quant: + assert a_scale.shape[0] == a.shape[0] and a_scale.shape[1] == 1, ( + f"{a_scale.shape[0]} == {a.shape[0]} and {a_scale.shape[1]} == 1") + else: + assert block_shape is not None + expected = (a.shape[0], cdiv(a.shape[1], block_shape[1])) + assert a_scale.shape == expected, f"{a_scale.shape} == {expected}" diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py new file mode 100644 index 0000000..492fb58 --- /dev/null +++ b/vllm/model_executor/layers/layernorm.py @@ -0,0 +1,319 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Custom normalization layers.""" +from typing import Optional, Union, Tuple + +import torch +import torch.nn as nn + +import vllm.envs as envs +from vllm.model_executor.custom_op import CustomOp + +from vllm.platforms import current_platform + + +def is_rocm_aiter_rmsnorm_enabled() -> bool: + return current_platform.is_rocm() \ + and envs.VLLM_ROCM_USE_AITER_RMSNORM \ + and envs.VLLM_ROCM_USE_AITER + + +def rms_norm(x: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float) -> torch.Tensor: + from vllm import _custom_ops as ops + out = torch.empty_like(x) + if envs.VLLM_USE_OPT_OP: + ops.rms_norm_opt( + out, + x, + weight, + variance_epsilon, + ) + else: + ops.rms_norm( + out, + x, + weight, + variance_epsilon, + ) + return out + + +def fused_add_rms_norm( + x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: + from vllm import _custom_ops as ops + if envs.VLLM_USE_OPT_OP: + ops.fused_add_rms_norm_opt( + x, + residual, + weight, + variance_epsilon, + ) + else: + ops.fused_add_rms_norm( + x, + residual, + weight, + variance_epsilon, + ) + return x, residual + + +def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float) -> torch.Tensor: + import aiter as rocm_aiter + if x.dim() > 2: + x_original_shape = x.shape + x = x.reshape(-1, x_original_shape[-1]) + x = rocm_aiter.rms_norm(x, weight, variance_epsilon) + return x.reshape(x_original_shape) + + return rocm_aiter.rms_norm(x, weight, variance_epsilon) + + +def rocm_aiter_fused_add_rms_norm( + x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: + + import aiter as rocm_aiter + + residual_out = torch.empty_like(residual) + output = torch.empty_like(x) + rocm_aiter.rmsnorm2d_fwd_with_add( + output, # output + x, # input + residual, # residual input + residual_out, # residual output + weight, + variance_epsilon, + ) + return output, residual_out + + +def dispatch_cuda_rmsnorm_func(add_residual: bool): + if add_residual: + if is_rocm_aiter_rmsnorm_enabled(): + return rocm_aiter_fused_add_rms_norm + return fused_add_rms_norm + + if is_rocm_aiter_rmsnorm_enabled(): + return rocm_aiter_rms_norm + return rms_norm + + +@CustomOp.register("rms_norm") +class RMSNorm(CustomOp): + """Root mean square normalization. + + Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. + Refer to https://arxiv.org/abs/1910.07467 + """ + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + var_hidden_size: Optional[int] = None, + has_weight: bool = True, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__() + + self.hidden_size = hidden_size + self.variance_epsilon = eps + self.variance_size_override = (None if var_hidden_size == hidden_size + else var_hidden_size) + self.has_weight = has_weight + if dtype is not None: + self.weight = torch.ones(hidden_size, dtype=dtype) + else: + self.weight = torch.ones(hidden_size) + if self.has_weight: + self.weight = nn.Parameter(self.weight) + + def forward_native( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """PyTorch-native implementation equivalent to forward().""" + orig_dtype = x.dtype + x = x.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + hidden_size = x.shape[-1] + if hidden_size != self.hidden_size: + raise ValueError("Expected hidden_size to be " + f"{self.hidden_size}, but found: {hidden_size}") + + if self.variance_size_override is None: + x_var = x + else: + if hidden_size < self.variance_size_override: + raise ValueError( + "Expected hidden_size to be at least " + f"{self.variance_size_override}, but found: {hidden_size}") + + x_var = x[:, :, :self.variance_size_override] + + variance = x_var.pow(2).mean(dim=-1, keepdim=True) + + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x.to(orig_dtype) + if self.has_weight: + x = x * self.weight + if residual is None: + return x + else: + return x, residual + + def forward_cuda( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + if self.variance_size_override is not None: + return self.forward_native(x, residual) + + add_residual = residual is not None + norm_func = dispatch_cuda_rmsnorm_func(add_residual) + + if add_residual: + return norm_func(x, residual, self.weight.data, + self.variance_epsilon) + else: + return norm_func(x, self.weight.data, self.variance_epsilon) + + def forward_apex( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + from apex.normalization.fused_layer_norm import fused_rms_norm_affine + add_residual = residual is not None + norm_func = dispatch_cuda_rmsnorm_func(add_residual) + + if add_residual: + return norm_func(x, residual, self.weight.data, + self.variance_epsilon) + else: + return fused_rms_norm_affine(x, self.weight.data, torch.Size((x.shape[-1],)), self.variance_epsilon) + + def forward_hpu( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + from vllm_hpu_extension.kernels import rms_norm + HPUFusedRMSNorm = rms_norm() + if HPUFusedRMSNorm is None: + return self.forward_native(x, residual) + if residual is not None: + orig_shape = x.shape + residual += x.view(residual.shape) + # Note: HPUFusedRMSNorm requires 3D tensors as inputs + x = HPUFusedRMSNorm.apply(residual, self.weight, + self.variance_epsilon) + return x.view(orig_shape), residual + + x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon) + return x + + def forward_xpu( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + if self.variance_size_override is not None: + return self.forward_native(x, residual) + + from vllm._ipex_ops import ipex_ops as ops + + if residual is not None: + ops.fused_add_rms_norm( + x, + residual, + self.weight.data, + self.variance_epsilon, + ) + return x, residual + return ops.rms_norm( + x, + self.weight.data, + self.variance_epsilon, + ) + + def extra_repr(self) -> str: + s = f"hidden_size={self.weight.data.size(0)}" + s += f", eps={self.variance_epsilon}" + return s + + +@CustomOp.register("gemma_rms_norm") +class GemmaRMSNorm(CustomOp): + """RMS normalization for Gemma. + + Two differences from the above RMSNorm: + 1. x * (1 + w) instead of x * w. + 2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w. + """ + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.weight = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + @staticmethod + def forward_static( + weight: torch.Tensor, + variance_epsilon: float, + x: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """PyTorch-native implementation equivalent to forward().""" + orig_dtype = x.dtype + if residual is not None: + if orig_dtype == torch.float16: + x = x + residual.float() + else: + x = x + residual + residual = x + + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + variance_epsilon) + # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + x = x * (1.0 + weight.float()) + x = x.to(orig_dtype) + return x if residual is None else (x, residual) + + def forward_native( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """PyTorch-native implementation equivalent to forward().""" + return self.forward_static(self.weight.data, self.variance_epsilon, x, + residual) + + def forward_cuda( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + if torch.compiler.is_compiling(): + return self.forward_native(x, residual) + + if not getattr(self, "_is_compiled", False): + self.forward_static = torch.compile( # type: ignore + self.forward_static) + self._is_compiled = True + return self.forward_native(x, residual) diff --git a/vllm/model_executor/layers/lightning_attn.py b/vllm/model_executor/layers/lightning_attn.py new file mode 100644 index 0000000..978086d --- /dev/null +++ b/vllm/model_executor/layers/lightning_attn.py @@ -0,0 +1,652 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from einops import rearrange + +from vllm.triton_utils import tl, triton + + +@triton.jit +def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n, + d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr, + NUM_BLOCK, CBLOCK: tl.constexpr): + # This kernel computes the diagonal blocks of the attention matrix + # Each diagonal block represents attention + # where queries attend to keys in the same block + off = tl.program_id(0) + off_bh = off // NUM_BLOCK # batch-head index + off_block = off % NUM_BLOCK # block index within the sequence + off_cblock = tl.program_id(1) # sub-block index within a block + + off_h = off_bh % h # head index + + # Calculate base offsets for the current batch and head + qk_offset = off_bh * n * d + v_offset = off_bh * n * e + o_offset = off_bh * n * e + + # Calculate offsets for the current block + block_offset = off_block * BLOCK + qk_block_offset = block_offset * d + v_block_offset = block_offset * e + o_block_offset = block_offset * e + + # Calculate offsets for the current sub-block + cblock_offset = off_cblock * CBLOCK + q_cblock_offset = cblock_offset * d + o_cblock_offset = cblock_offset * e + + # Calculate pointers to the query, key, value, and output tensors + Q_block_ptr = (Q + qk_offset + qk_block_offset + q_cblock_offset + + tl.arange(0, CBLOCK)[:, None] * d + + tl.arange(0, d)[None, :]) + K_trans_block_ptr = (K + qk_offset + qk_block_offset + + tl.arange(0, CBLOCK)[None, :] * d + + tl.arange(0, d)[:, None]) + V_block_ptr = (V + v_offset + v_block_offset + + tl.arange(0, CBLOCK)[:, None] * e + + tl.arange(0, e)[None, :]) + O_block_ptr = (Out + o_offset + o_block_offset + o_cblock_offset + + tl.arange(0, CBLOCK)[:, None] * e + + tl.arange(0, e)[None, :]) + + # Load the decay rate for the current head + S_block_ptr = S + off_h + s = tl.load(S_block_ptr) + + i = off_cblock + q_index = tl.arange(0, CBLOCK) + i * CBLOCK + + # Load query values + q = tl.load(Q_block_ptr, + mask=block_offset + q_index[:, None] < n, + other=0.0).to(tl.float32) + + # Initialize output accumulator + qkv = tl.zeros([CBLOCK, e], dtype=tl.float32) + + # Process all sub-blocks up to and + # including the current one (causal attention) + for j in range(i + 1): + kv_index = tl.arange(0, CBLOCK) + j * CBLOCK + diff = q_index[:, None] - kv_index[None, :] + s_index = s * diff + # Apply causal mask: only attend to positions before the current one + s_index = tl.where(diff >= 0, -s_index, float("-inf")) + decay = tl.exp(s_index) + + # Load key and value + k_trans = tl.load( + K_trans_block_ptr, + mask=block_offset + kv_index[None, :] < n, + other=0.0, + ).to(tl.float32) + v = tl.load( + V_block_ptr, + mask=block_offset + kv_index[:, None] < n, + other=0.0, + ).to(tl.float32) + + # Compute attention scores and apply decay + qk = tl.dot(q, k_trans) * decay + + # Compute weighted values and accumulate + qkv += tl.dot(qk, v) + + # Move to the next sub-block + K_trans_block_ptr += CBLOCK * d + V_block_ptr += CBLOCK * e + + # Store the result + tl.store( + O_block_ptr, + qkv.to(O_block_ptr.dtype.element_ty), + mask=block_offset + q_index[:, None] < n, + ) + + +@triton.jit +def _fwd_kv_parallel( + K, + V, + K_decay, + KV, + b: tl.constexpr, + h: tl.constexpr, + n, + d: tl.constexpr, + e: tl.constexpr, + BLOCK: tl.constexpr, + NUM_BLOCK, + D_FBLOCK: tl.constexpr, + E_FBLOCK: tl.constexpr, + NUM_FBLOCK: tl.constexpr, + CBLOCK: tl.constexpr, + NUM_CBLOCK: tl.constexpr, +): + # This kernel computes the key-value outer + # products for each block in parallel + off_bh = tl.program_id(0) # batch-head index + off_block = tl.program_id(1) # block index + + off_h = off_bh % h # head index + + block_offset = off_block * BLOCK + + # Calculate offsets for the current block + k_block_offset = block_offset * d + v_block_offset = block_offset * e + kv_block_offset = off_block * d * e + + # Calculate base offsets for the current batch and head + k_offset = off_bh * n * d + v_offset = off_bh * n * e + kv_offset = off_bh * NUM_BLOCK * d * e + + # Calculate pointers to the key, value, and key-value tensors + K_trans_block_ptr = (K + k_offset + k_block_offset + + tl.arange(0, CBLOCK)[None, :] * d + + tl.arange(0, D_FBLOCK)[:, None]) + V_block_ptr = (V + v_offset + v_block_offset + + tl.arange(0, CBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :]) + KV_block_ptr = (KV + kv_offset + kv_block_offset + + tl.arange(0, D_FBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :]) + + # Load the decay factors for the current head and block + k_decay_ptr = (K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :]) + + kv_index = tl.arange(0, CBLOCK) + + # Initialize the key-value outer product accumulator + kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32) + + # Handle the last block which might be smaller than BLOCK + if off_block == NUM_BLOCK - 1: + split_n = n - (NUM_BLOCK - 1) * BLOCK + else: + split_n = BLOCK + left_shift = tl.cdiv(split_n, CBLOCK) * CBLOCK - split_n + num_blocks = min(tl.cdiv(split_n, CBLOCK), NUM_CBLOCK) + k_decay_ptr += (NUM_CBLOCK - num_blocks) * CBLOCK + + # Process all sub-blocks in the current block + for j in range(num_blocks): + left_bound = (1 - j) * left_shift + # Load key and value, handling boundary conditions + k_trans = tl.load(K_trans_block_ptr - left_shift * d, + mask=kv_index[None, :] >= left_bound, + other=0.0) + v = tl.load(V_block_ptr - left_shift * e, + mask=kv_index[:, None] >= left_bound, + other=0.0) + + # Load decay factor and compute weighted key-value outer product + k_decay = tl.load(k_decay_ptr) + kv += tl.dot(k_trans * k_decay, v) + + # Move to the next sub-block + K_trans_block_ptr += CBLOCK * d + V_block_ptr += CBLOCK * e + k_decay_ptr += CBLOCK + + # Store the result + tl.store(KV_block_ptr, kv.to(KV_block_ptr.dtype.element_ty)) + + +@triton.jit +def _fwd_kv_reduce(S, KV, KV_HISTORY, b: tl.constexpr, h: tl.constexpr, n, + d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr, + NUM_BLOCK, D_FBLOCK: tl.constexpr, E_FBLOCK: tl.constexpr): + # This kernel reduces the key-value outer products + # across blocks and updates the KV history + off_bh = tl.program_id(0) # batch-head index + off_h = off_bh % h # head index + + kv_offset = off_bh * NUM_BLOCK * d * e + + # Calculate pointer to the key-value tensor + KV_block_ptr = (KV + kv_offset + tl.arange(0, D_FBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :]) + + # Load the decay rate for the current head + s_ptrs = S + off_h + s = tl.load(s_ptrs) + + # Calculate pointer to the key-value history tensor + kv_history_offset = off_bh * d * e + KV_HISTORY_block_ptr = (KV_HISTORY + kv_history_offset + + tl.arange(0, D_FBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :]) + + # Load the previous key-value history + kv_pre = tl.load(KV_HISTORY_block_ptr).to(tl.float32) + + # Process all blocks in reverse order to compute the prefix sum + for i in range(NUM_BLOCK): + block_size = min(n - i * BLOCK, BLOCK) + # Compute decay factor for the current block + block_decay = tl.exp(-s.to(tl.float32) * block_size) + + # Load the current key-value outer product + kv_cur = tl.load(KV_block_ptr).to(tl.float32) + # Store the previous key-value history to the current block + tl.store(KV_block_ptr, kv_pre.to(KV_block_ptr.dtype.element_ty)) + + # Update the key-value history with the current block + kv_pre = block_decay * kv_pre + kv_cur + KV_block_ptr += d * e + + # Store the updated key-value history + tl.store(KV_HISTORY_block_ptr, kv_pre) + + +@triton.jit +def _fwd_none_diag_kernel( + Q, + Out, + S, + KV, + b: tl.constexpr, + h: tl.constexpr, + n, + d: tl.constexpr, + e: tl.constexpr, + BLOCK: tl.constexpr, + NUM_BLOCK, + E_FBLOCK: tl.constexpr, + CBLOCK: tl.constexpr, + NUM_CBLOCK: tl.constexpr, +): + # This kernel computes the non-diagonal blocks of the attention matrix + # Each non-diagonal block represents attention + # where queries attend to keys in different blocks + off_bh = tl.program_id(0) # batch-head index + off_h = off_bh % h # head index + + off_nc = tl.program_id(1) + off_n = off_nc // NUM_CBLOCK # block index + off_c = off_nc % NUM_CBLOCK # sub-block index + off_e = tl.program_id(2) # output feature block index + + n_offset = off_n * BLOCK + c_offset = off_c * CBLOCK + e_offset = off_e * E_FBLOCK + block_offset = n_offset + c_offset + + # Calculate offsets for the current batch, head, and block + q_offset = off_bh * n * d + (n_offset + c_offset) * d + o_offset = off_bh * n * e + (n_offset + c_offset) * e + e_offset + kv_offset = off_bh * NUM_BLOCK * d * e + off_n * d * e + e_offset + + # Calculate pointers to the query, output, and key-value tensors + Q_block_ptr = (Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d + + tl.arange(0, d)[None, :]) + O_block_ptr = (Out + o_offset + tl.arange(0, CBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :]) + KV_block_ptr = (KV + kv_offset + tl.arange(0, d)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :]) + + # Load the decay rate for the current head + S_block_ptr = S + off_h + s = tl.load(S_block_ptr) + + c_array = tl.arange(0, CBLOCK) + + # Load the key-value outer product for the current block + kv = tl.load(KV_block_ptr).to(tl.float32) + q_index = block_offset + tl.arange(0, CBLOCK) + + # Load query values + q = tl.load(Q_block_ptr, mask=q_index[:, None] < n, + other=0.).to(tl.float32) + + # Compute decay factors for the current sub-block + q_decay = tl.exp(-s.to(tl.float32) * (off_c * CBLOCK + c_array[:, None])) + + # Compute non-diagonal attention output + qkv_none_diag = tl.dot(q, kv) * q_decay + + # Load diagonal attention output (computed by _fwd_diag_kernel) + qkv_diag = tl.load(O_block_ptr, mask=q_index[:, None] < n, + other=0.).to(tl.float32) + + # Combine diagonal and non-diagonal attention outputs + qkv = qkv_diag + qkv_none_diag + + # Store the result + tl.store(O_block_ptr, + qkv.to(O_block_ptr.dtype.element_ty), + mask=q_index[:, None] < n) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, s, kv_history): + # Forward pass of the lightning attention algorithm + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + s = s.contiguous() + + # Check CUDA compute capability + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + raise RuntimeError("Flash attention currently only supported", + "for compute capability >= 80") + + # Get input dimensions + b, h, n, d = q.shape + e = v.shape[-1] + + # Initialize output tensor + o = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device) + + # Set block sizes + BLOCK = 256 + NUM_BLOCK = triton.cdiv(n, BLOCK) + + CBLOCK = 32 + NUM_CBLOCK = BLOCK // CBLOCK + assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK" + + # Compute decay factors for keys + array = torch.arange(0, BLOCK, device=q.device) + 1 + k_decay = torch.exp(-s * (BLOCK - array.reshape(1, -1))) + + # Step 1: Compute diagonal blocks of attention + grid = (b * h * NUM_BLOCK, NUM_CBLOCK) + _fwd_diag_kernel[grid](q, + k, + v, + o, + s, + b, + h, + n, + d, + e, + BLOCK=BLOCK, + NUM_BLOCK=NUM_BLOCK, + CBLOCK=CBLOCK) + + # Set feature block sizes + NUM_FBLOCK = 1 + D_FBLOCK = d // NUM_FBLOCK + assert d % NUM_FBLOCK == 0 + E_FBLOCK = e // NUM_FBLOCK + assert e % NUM_FBLOCK == 0 + + CBLOCK = 64 + NUM_CBLOCK = BLOCK // CBLOCK + assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK" + + # Step 2: Compute key-value outer products for each block in parallel + kv = torch.empty((b, h, NUM_BLOCK, d, e), + dtype=torch.float32, + device=q.device) + grid = (b * h, NUM_BLOCK) + _fwd_kv_parallel[grid]( + k, + v, + k_decay, + kv, + b, + h, + n, + d, + e, + BLOCK=BLOCK, + NUM_BLOCK=NUM_BLOCK, + D_FBLOCK=D_FBLOCK, + E_FBLOCK=E_FBLOCK, + NUM_FBLOCK=NUM_FBLOCK, + CBLOCK=CBLOCK, + NUM_CBLOCK=NUM_CBLOCK, + ) + + # Step 3: Reduce key-value outer products + # across blocks and update KV history + grid = (b * h, NUM_FBLOCK) + _fwd_kv_reduce[grid](s, + kv, + kv_history, + b, + h, + n, + d, + e, + BLOCK=BLOCK, + NUM_BLOCK=NUM_BLOCK, + D_FBLOCK=D_FBLOCK, + E_FBLOCK=E_FBLOCK) + + # Step 4: Compute non-diagonal blocks of attention + grid = (b * h, NUM_BLOCK * NUM_CBLOCK) + _fwd_none_diag_kernel[grid]( + q, + o, + s, + kv, + b, + h, + n, + d, + e, + BLOCK=BLOCK, + NUM_BLOCK=NUM_BLOCK, + E_FBLOCK=E_FBLOCK, + CBLOCK=CBLOCK, + NUM_CBLOCK=NUM_CBLOCK, + ) + + # Save tensors for backward pass + ctx.save_for_backward(q, k, v, s, kv) + ctx.BLOCK = BLOCK + + return o, torch.cat([kv, kv_history.unsqueeze(2)], dim=2) + + +# Apply the lightning attention function +lightning_attention_ = _attention.apply + + +def lightning_attention(q, k, v, ed, block_size=256, kv_history=None): + """ + Apply lightning attention algorithm + to compute attention efficiently. + + Args: + q: Query tensor of shape [batch, heads, seq_len, dim] + k: Key tensor of shape [batch, heads, seq_len, dim] + v: Value tensor of shape [batch, heads, seq_len, dim_v] + ed: Decay rate tensor of shape [heads] + block_size: Size of blocks for block-sparse attention + kv_history: Optional key-value history from previous computations + + Returns: + output: Attention output + kv: Updated key-value history + """ + d = q.shape[-1] + e = v.shape[-1] + + if ed.dim() == 1: + ed = ed.view(1, -1, 1, 1) + + # Split the computation into chunks for better parallelism + m = 128 if d >= 128 else 64 + assert d % m == 0, f"Dimension d ({d}) must be divisible by m ({m})" + arr = [m * i for i in range(d // m + 1)] + if arr[-1] != d: + arr.append(d) + n = len(arr) + output = 0 + + # Initialize or clone key-value history + if kv_history is None: + kv_history = torch.zeros((q.shape[0], q.shape[1], d, e), + dtype=torch.float32, + device=q.device) + else: + kv_history = kv_history.clone().contiguous() + + # Process each chunk and accumulate results + for i in range(n - 1): + s = arr[i] + e = arr[i + 1] + q1 = q[..., s:e] + k1 = k[..., s:e] + o, kv = lightning_attention_(q1, k1, v, ed, kv_history) + output = output + o + return output, kv + + +@triton.jit +def _linear_attn_decode_kernel( + q_ptr, + k_ptr, + v_ptr, + kv_cache_ptr, + slope_rate, + slot_idx, + output_ptr, + D: tl.constexpr, + qkv_b_stride, + qkv_h_stride, + cache_b_stride, + cache_h_stride, + cache_d0_stride, + cache_d1_stride, + BLOCK_SIZE: tl.constexpr, +): + """ + Kernel for linear attention decoding with KV cache. + + This kernel computes attention for a single token using the KV cache. + """ + pid_b = tl.program_id(0) # batch index + pid_h = tl.program_id(1) # head index + pid_d = tl.program_id(2) # dimension block index + + # Load slot index for the current batch + slot_id = tl.load(slot_idx + pid_b) + + # Skip if slot_id is -1 (padding) + if slot_id == -1: + return + + batch_id = pid_b + head_id = pid_h + + # Load decay rate for the current head + ratio = tl.load(slope_rate + pid_h) + + # Calculate offsets for dimensions + qk_d_offsets = tl.arange(0, D) + v_d_offsets = tl.arange(0, BLOCK_SIZE) + pid_d * BLOCK_SIZE + cache_d_offsets = qk_d_offsets[:, None] * cache_d0_stride + v_d_offsets[ + None, :] * cache_d1_stride + + # Calculate offsets for the current batch and head + q_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride + k_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride + v_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride + + cache_offset = slot_id * cache_b_stride + head_id * cache_h_stride + + # Create masks for loading tensors + qk_mask = qk_d_offsets < D + v_mask = v_d_offsets < D + + # Load query, key, and value tensors + q = tl.load(q_ptr + q_offset + qk_d_offsets, mask=qk_mask, other=0.0) + k = tl.load(k_ptr + k_offset + qk_d_offsets, mask=qk_mask, other=0.0) + v = tl.load(v_ptr + v_offset + v_d_offsets, mask=v_mask, other=0.0) + + # Compute key-value outer product + kv_outer = k[:, None] * v[None, :] + kv_mask = qk_mask[:, None] & v_mask[None, :] + + # Apply decay to previous KV cache + ratio = tl.exp(-ratio) + kv_ptr = kv_cache_ptr + cache_offset + cache_d_offsets + kv_cache_old = tl.load(kv_ptr, mask=kv_mask, other=0.0) + kv_outer = kv_outer + ratio * kv_cache_old + + # Compute attention output + output = q[:, None].to(tl.float32) * kv_outer + output = tl.sum(output, axis=0) + + # Update KV cache and store output + tl.store(kv_ptr, kv_outer, mask=kv_mask) + tl.store(output_ptr + q_offset + v_d_offsets, output, mask=v_mask) + + +def linear_decode_forward_triton( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kv_caches: torch.Tensor, + slope_rate: torch.Tensor, + slot_idx: torch.Tensor, + BLOCK_SIZE: int = 32, +) -> torch.Tensor: + """ + Perform linear attention decoding using Triton kernels. + + Args: + q: Query tensor of shape [B, H, 1, D] + k: Key tensor of shape [B, H, 1, D] + v: Value tensor of shape [B, H, 1, D] + kv_caches: Key-value cache tensor + slope_rate: Decay rate tensor + slot_idx: Slot indices for batches + BLOCK_SIZE: Size of blocks for processing + + Returns: + output: Attention output tensor + """ + B, H, _, D = q.shape + assert k.shape == (B, H, 1, D) + assert v.shape == (B, H, 1, D) + + # Initialize output tensor + output = torch.empty_like(q) + + # Set grid dimensions for the kernel + grid = (B, H, D // BLOCK_SIZE) + + # Calculate strides for tensors + qkv_b_stride = q.stride(0) + qkv_h_stride = q.stride(1) + + cache_b_stride = kv_caches.stride(0) + cache_h_stride = kv_caches.stride(1) + cache_d0_stride = kv_caches.stride(2) + cache_d1_stride = kv_caches.stride(3) + + # Launch the kernel + _linear_attn_decode_kernel[grid]( + q, + k, + v, + kv_caches, + slope_rate, + slot_idx, + output, + D, + qkv_b_stride, + qkv_h_stride, + cache_b_stride, + cache_h_stride, + cache_d0_stride, + cache_d1_stride, + BLOCK_SIZE=BLOCK_SIZE, + ) + + # Reshape output and return + output = rearrange(output, "b h n d -> b n (h d)") + return output.squeeze(1).contiguous() diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py new file mode 100644 index 0000000..d706066 --- /dev/null +++ b/vllm/model_executor/layers/linear.py @@ -0,0 +1,1744 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import itertools +from abc import abstractmethod +from typing import Any, Literal, Optional, Union +import vllm.envs as envs +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter, UninitializedParameter + +from vllm import envs +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.utils import dispatch_unquantized_gemm +# yapf: disable +from vllm.model_executor.parameter import (BasevLLMParameter, + BlockQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + PerTensorScaleParameter, + RowvLLMParameter) +# yapf: enable +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform + +import os +from vllm.model_executor.utils import gemm_bank_conf + +if envs.USE_FUSED_RMS_QUANT: + try: + from lmslim.quantize.quant_ops import lm_faster_rmsquant + except Exception as e: + print(f"Error: Import fused rmsquant error: {e}") + +logger = init_logger(__name__) + +WEIGHT_LOADER_V2_SUPPORTED = [ + "CompressedTensorsLinearMethod", + "BitBLASLinearMethod", + "GPTQBitBLASLinearMethod", + "AWQMarlinLinearMethod", + "AWQLinearMethod", + "GPTQMarlinLinearMethod", + "Fp8LinearMethod", + "MarlinLinearMethod", + "QQQLinearMethod", + "GPTQMarlin24LinearMethod", + "TPUInt8LinearMethod", + "GPTQLinearMethod", + "FBGEMMFp8LinearMethod", + "ModelOptFp8LinearMethod", + "IPEXAWQLinearMethod", + "IPEXGPTQLinearMethod", + "HQQMarlinMethod", + "QuarkLinearMethod", + "ModelOptNvFp4LinearMethod", + "BlockInt8LinearMethod", +] + + +def adjust_bitblas_shard(param, shard_size, shard_offset): + bitblas_tile_size = getattr(param, "bitblas_tile_size", None) + if bitblas_tile_size is not None: + return (shard_size // bitblas_tile_size, + shard_offset // bitblas_tile_size) + + return shard_size, shard_offset + + +def adjust_marlin_shard(param, shard_size, shard_offset): + marlin_tile_size = getattr(param, "marlin_tile_size", None) + if marlin_tile_size is None: + return shard_size, shard_offset + + return shard_size * marlin_tile_size, shard_offset * marlin_tile_size + + +def adjust_bitsandbytes_4bit_shard(param: Parameter, + shard_offsets: dict[str, tuple[int, int]], + loaded_shard_id: str) -> tuple[int, int]: + """Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" + + total, _ = shard_offsets["total"] + orig_offset, orig_size = shard_offsets[loaded_shard_id] + + quantized_total = param.data.shape[0] + quantized_offset = orig_offset * quantized_total // total + quantized_size = orig_size * quantized_total // total + + return quantized_size, quantized_offset + + +def adjust_scalar_to_fused_array(param, loaded_weight, shard_id): + """For fused modules (QKV and MLP) we have an array of length + N that holds 1 scale for each "logical" matrix. So the param + is an array of length N. The loaded_weight corresponds to + one of the shards on disk. Here, we slice the param based on + the shard_id for loading. + """ + qkv_idxs = {"q": 0, "k": 1, "v": 2} + + if isinstance(shard_id, str): + shard_id = qkv_idxs[shard_id] + elif not isinstance(shard_id, int): + raise ValueError(f"Unknown Shard Id {shard_id}") + + # AutoFP8 scales do not have a shape + # compressed-tensors scales do have a shape + if len(loaded_weight.shape) != 0: + assert loaded_weight.shape[0] == 1 + loaded_weight = loaded_weight[0] + + if envs.VLLM_USE_NN: + return param[shard_id], loaded_weight.t() + else: + return param[shard_id], loaded_weight + + +# TODO(Isotr0py): We might need a more flexible structure to handle +# bitsandbytes shard offsets. +def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]): + """ + Separate the BitsAndBytes 4-bit shard. + + For example, given bnb weight attributes as below: + { + 'bnb_shard_offsets': array([0, 4, 8, 16]), + 'bnb_quant_state': {0: ..., 1: ..., 2: ...}, + } + + The function will return: + { + 'bnb_shard_offsets': array([0, 4]), + 'bnb_quant_state': {0: ...}, + } + and + { + 'bnb_shard_offsets': array([0, 4, 12]), + 'bnb_quant_state': {0: ..., 1: ...}, + } + """ + shard_offsets = bnb_weight_attrs["bnb_shard_offsets"] + offset_l = shard_offsets[:2] + offset_r = shard_offsets[1:] - shard_offsets[1] + quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]} + quant_state_r = { + i - 1: bnb_weight_attrs["bnb_quant_state"][i] + for i in range(1, + len(shard_offsets) - 1) + } + left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l) + right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r) + return left, right + + +class LinearMethodBase(QuantizeMethodBase): + """Base class for different (maybe quantized) linear methods.""" + + @abstractmethod + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """Create weights for a linear layer. + The weights will be set as attributes of the layer. + + Args: + layer: The layer that is using the LinearMethodBase factory. + input_size_per_partition: Size of the weight input dim on rank X. + output_partition_sizes: Sizes of the output dim of each logical + weight on rank X. E.g., output_partition_sizes for QKVLinear + is a list contains the width of Wq, Wk, Wv on rank X. + input_size: Size of the input dim of the weight across all ranks. + output_size: Size of the output dim of the weight across all ranks. + params_dtype: Datatype of the parameters. + """ + raise NotImplementedError + + @abstractmethod + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """Apply the weights in layer to the input tensor. + Expects create_weights to have been called before on the layer.""" + raise NotImplementedError + + +class UnquantizedLinearMethod(LinearMethodBase): + """Linear method without quantization.""" + + def __init__(self): + self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' + self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + if envs.VLLM_USE_NN: + weight = Parameter(torch.empty(input_size_per_partition, + sum(output_partition_sizes), + dtype=params_dtype), + requires_grad=False) + else: + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL: + N, K = layer.weight.size() + dtype = layer.weight.dtype + if (torch._C._cpu._is_amx_tile_supported() + and dtype == torch.bfloat16 and N % 32 == 0 + and K % 32 == 0): + packed_weight = torch.ops._C.convert_weight_packed( + layer.weight) + assert packed_weight.size() == layer.weight.size() + layer.weight.copy_(packed_weight) + if layer.bias is not None: + layer.bias = Parameter(layer.bias.to(torch.float32), + requires_grad=False) + layer.use_cpu_sgl = True + else: + logger.warning( + "CPU SGL kernels require Intel AMX support," + " bfloat16 weight, IC and OC are divisible by 32.") + layer.use_cpu_sgl = False + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_llama_nn: + if gemm_bank_conf(layer.weight.shape[1] - 32) and os.environ['GEMM_PAD'] == '1': + layer.weight = layer.weight[:,:-32] + + if bias is not None: + if len(x.shape) == 2: + return torch.addmm(bias, x, layer.weight) + else: + return torch.matmul(x, layer.weight) + bias + else: + return torch.matmul(x, layer.weight) + else: + if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]: + return dispatch_unquantized_gemm()(x, layer.weight.t(), bias) + else: + return dispatch_unquantized_gemm()(x, layer.weight, bias) + + +class LinearBase(torch.nn.Module): + """Base linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + return_bias: If true, return bias together with outputs in forward pass. + """ + + def __init__( + self, + input_size: int, + output_size: int, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + ): + super().__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + if quant_config is None: + self.quant_method: Optional[ + QuantizeMethodBase] = UnquantizedLinearMethod() + else: + self.quant_method = quant_config.get_quant_method(self, + prefix=prefix) + self.return_bias = return_bias + + def forward( + self, x: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + raise NotImplementedError + + +class ReplicatedLinear(LinearBase): + """Replicated linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + return_bias: If true, return bias together with outputs in forward pass. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + eps: Optional[float] = 1e-6, + prefix: str = "", + *, + return_bias: bool = True, + ): + super().__init__(input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix=prefix, + return_bias=return_bias) + self.eps = eps + + # All the linear layer supports quant method. + assert self.quant_method is not None + self.quant_method.create_weights(self, + self.input_size, [self.output_size], + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader) + + if bias: + self.bias = Parameter( + torch.empty(self.output_size, dtype=self.params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + # If the weight on disk does not have a shape, give it one + # (such scales for AutoFp8). + # Special case for GGUF + + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, UninitializedParameter): + param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod) + if envs.VLLM_USE_NN and not is_quantization: + loaded_weight = loaded_weight.t() + + assert param.size() == loaded_weight.size(), ( + f"Tried to load weights of size {loaded_weight.size()}" + f"to a parameter of size {param.size()}") + param.data.copy_(loaded_weight) + + def forward( + self, + input_: torch.Tensor, + rms_weight: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None, + quant_args: Optional[list] = None, + update_hd: Optional[bool] = True + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + if envs.USE_FUSED_RMS_QUANT and (rms_weight is not None or quant_args is not None): + if quant_args is not None: + input_quant_args = quant_args + + bias = self.bias if not self.skip_bias_add else None + assert self.quant_method is not None + output = self.quant_method.apply(self, input_, bias, input_quant_args) + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, output_bias + + else: + i_q, _scales = lm_faster_rmsquant(input=input_, + rms_weight=rms_weight, + epsilon=self.eps, + quant_dtype=torch.int8, + residual=residual, + update_input=update_hd + ) + + new_residual = residual + input_quant_args = [i_q, _scales] + + bias = self.bias if not self.skip_bias_add else None + assert self.quant_method is not None + output = self.quant_method.apply(self, input_, bias, input_quant_args) + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, new_residual, output_bias, input_quant_args + + else: + bias = self.bias if not self.skip_bias_add else None + assert self.quant_method is not None + output = self.quant_method.apply(self, input_, bias) + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, output_bias + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + return s + + +class ColumnParallelLinear(LinearBase): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Args: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. + gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + output_sizes: list of output sizes packed into one output, like for QKV + the list would be size 3. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + output_sizes: Optional[list[int]] = None, + eps: Optional[float] = 1e-6, + prefix: str = "", + *, + return_bias: bool = True, + ): + # Divide the weight matrix along the last dimension. + self.tp_size = get_tensor_model_parallel_world_size() + self.input_size_per_partition = input_size + self.output_size_per_partition = divide(output_size, self.tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. + if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, self.tp_size) + for output_size in self.output_sizes + ] + + super().__init__(input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias) + self.eps = eps + self.gather_output = gather_output + + if output_sizes is None: + output_sizes = [output_size] + + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 if self.quant_method.__class__.__name__ + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + if bias: + self.bias = Parameter( + torch.empty(self.output_size_per_partition, + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + output_dim = getattr(param, "output_dim", None) + + is_sharded_weight = getattr(param, "is_sharded_weight", False) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit + is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod) + + # Special case for GGUF + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, UninitializedParameter): + final_shape = list(loaded_weight.shape) + if output_dim is not None: + tp_size = get_tensor_model_parallel_world_size() + assert final_shape[output_dim] % tp_size == 0 + final_shape[output_dim] = final_shape[output_dim] // tp_size + param.materialize(final_shape, dtype=loaded_weight.dtype) + + param_data = param.data + if output_dim is not None and not is_sharded_weight: + if not envs.VLLM_USE_NN or len(param_data.shape)==1 or is_quantization: + shard_size = param_data.shape[output_dim] + else: + shard_size = param_data.shape[int(not(output_dim))] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + if envs.VLLM_USE_NN and not is_quantization: + loaded_weight = loaded_weight.t() + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor): + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + assert loaded_weight.numel() == 1 + loaded_weight = loaded_weight.reshape(1) + param.load_column_parallel_weight(loaded_weight=loaded_weight) + + def forward( + self, input_, + rms_weight: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None, + update_hd: Optional[bool] = True + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + if envs.USE_FUSED_RMS_QUANT and rms_weight is not None: + input_quant_args = None + assert rms_weight is not None + i_q, _scales = lm_faster_rmsquant(input=input_, + rms_weight=rms_weight, + epsilon=self.eps, + quant_dtype=torch.int8, + residual=residual, + update_input=update_hd) + new_residual = residual + input_quant_args = [i_q, _scales] + + bias = self.bias if not self.skip_bias_add else None + + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args) + if self.gather_output: + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, new_residual, output_bias + else: + bias = self.bias if not self.skip_bias_add else None + # Matrix multiply. + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self, input_, bias) + if self.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, output_bias + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size_per_partition}" + s += f", bias={self.bias is not None}" + s += f", tp_size={get_tensor_model_parallel_world_size()}" + s += f", gather_output={self.gather_output}" + return s + + +class MergedColumnParallelLinear(ColumnParallelLinear): + """Packed linear layers with column parallelism. + + Similar to ColumnParallelLinear, but the weight matrix is concatenated + along the output dimension. When the weight matrix is loaded, the + different partitions are sharded separately. + + Args: + input_size: input dimension of the linear layer. + output_sizes: list of output dimensions of the linear layer. + bias: If true, add bias. + gather_output: If true, call all-gather on output and make the output + available to all GPUs, otherwise, every GPU will have + its own output. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + return_bias: If true, return bias together with outputs in forward pass. + """ + + def forward( + self, input_, + rms_weight: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None, + update_hd: Optional[bool] = True + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + if envs.USE_FUSED_RMS_QUANT and rms_weight is not None: + input_quant_args = None + assert residual is not None and rms_weight is not None + i_q, _scales = lm_faster_rmsquant(input=input_, + rms_weight=rms_weight, + epsilon=self.eps, + quant_dtype=torch.int8, + residual=residual, + update_input=update_hd) + + new_residual = residual + input_quant_args = [i_q, _scales] + + + bias = self.bias if not self.skip_bias_add else None + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args) + + if self.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, new_residual, output_bias + else: # not USE_FUSED_RMS_QUANT + bias = self.bias if not self.skip_bias_add else None + + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self, input_, bias) + if self.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, output_bias + + def __init__( + self, + input_size: int, + output_sizes: list[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + eps: Optional[float] = 1e-6, + prefix: str = "", + *, + return_bias: bool = True, + ): + self.eps = eps + self.output_sizes = output_sizes + tp_size = get_tensor_model_parallel_world_size() + assert all(output_size % tp_size == 0 for output_size in output_sizes) + super().__init__(input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + return_bias=return_bias) + + def weight_loader(self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None): + + # Special case for GGUF + # initialize GGUF param after we know the quantize type + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + if loaded_shard_id is not None: + param.data[loaded_shard_id].copy_(loaded_weight) + param.shard_weight_type[loaded_shard_id] = loaded_weight.item() + else: + param.shard_weight_type = { + i: loaded_weight.item() + for i, _ in enumerate(self.output_sizes) + } + return + + if is_gguf_weight: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + output_dim = getattr(param, "output_dim", None) + shard_size = loaded_weight.size(output_dim) // tp_size + start_idx = tp_rank * shard_size + + if loaded_shard_id is not None: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + param.shard_id.append(loaded_shard_id) + param.shard_id_map[loaded_shard_id] = len(param.data_container) + param.data_container.append(loaded_weight) + return + + param_data = param.data + output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + # Special case for per-tensor scale to load scalar into fused array. + needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) + is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod) + + if loaded_shard_id is None: + # Loaded weight is already fused on disk (mlp). + # (e.g., Phi-3's gate_up_proj). + if output_dim is None: + if needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, 0) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + current_shard_offset = 0 + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) + shard_offsets: list[tuple[int, int, int]] = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + shard_size, shard_offset = adjust_bitblas_shard( + param, shard_size, shard_offset) + + if use_bitsandbytes_4bit: + index = list(itertools.accumulate([0] + self.output_sizes)) + orig_offsets = { + str(i): (index[i], size) + for i, size in enumerate(self.output_sizes) + } + orig_offsets["total"] = (self.output_size, 0) + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( + param, orig_offsets, str(shard_id)) + + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + assert loaded_shard_id < len(self.output_sizes) + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + if output_dim is not None: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + # Special case for quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + shard_size, shard_offset = adjust_bitblas_shard( + param, shard_size, shard_offset) + + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) + is_sharded_weight = getattr(param, "is_sharded_weight", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit + + if use_bitsandbytes_4bit: + shard_size = loaded_weight.shape[output_dim] + shard_offset = loaded_weight.shape[output_dim] * \ + loaded_shard_id + + if not envs.VLLM_USE_NN or is_quantization: + param_data = param_data.narrow(output_dim, shard_offset, shard_size) + else: + param_data = param_data.narrow(int(not(output_dim)), shard_offset, shard_size) + + start_idx = tp_rank * shard_size + if not is_sharded_weight: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + # Special case for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_offset = loaded_shard_id * shard_size + param_data = param_data.narrow(0, shard_offset, shard_size) + + # Special case for per-tensor scales in fused case. + elif needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, loaded_shard_id) + + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "MergedColumnParallelLinear, assume the weight is " + "the same for all partitions.") + + if envs.VLLM_USE_NN and not is_quantization: + loaded_weight = loaded_weight.t() + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, + loaded_weight: torch.Tensor): + """ + Handle special case for models where MLP layers are already + fused on disk. In this case, we have no shard id. This function + determmines the shard id by splitting these layers and then calls + the weight loader using the shard id. + + An example of a model with these fused layers: + https://huggingface.co/microsoft/Phi-3-mini-4k-instruct + """ + + current_shard_offset = 0 + shard_offsets: list[tuple[int, int, int]] = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if isinstance(param, (PackedColumnParameter, PackedvLLMParameter + )) and param.packed_dim == param.output_dim: + shard_size, shard_offset = \ + param.adjust_shard_indexes_for_packing( + shard_size=shard_size, shard_offset=shard_offset) + + loaded_weight_shard = loaded_weight.narrow(param.output_dim, + shard_offset, + shard_size) + self.weight_loader_v2(param, loaded_weight_shard, shard_id) + + def weight_loader_v2(self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None): + if loaded_shard_id is None: + if isinstance(param, PerTensorScaleParameter): + param.load_merged_column_weight(loaded_weight=loaded_weight, + shard_id=0) + return + elif type(param) in (RowvLLMParameter, BasevLLMParameter): + param.load_merged_column_weight(loaded_weight=loaded_weight) + return + # TODO: @dsikka - move to parameter.py + self._load_fused_module_from_checkpoint(param, loaded_weight) + return + + assert loaded_shard_id < len(self.output_sizes) + + tp_size = get_tensor_model_parallel_world_size() + + if isinstance(param, BlockQuantScaleParameter): + from vllm.model_executor.layers.quantization.fp8 import ( + Fp8LinearMethod, Fp8MoEMethod) + + from vllm.model_executor.layers.quantization.blockwise_int8 import ( + BlockInt8LinearMethod, BlockInt8MoEMethod) + assert self.quant_method is not None + assert isinstance(self.quant_method, + (Fp8LinearMethod, Fp8MoEMethod, BlockInt8LinearMethod, BlockInt8MoEMethod)) + weight_block_size = self.quant_method.quant_config.weight_block_size + assert weight_block_size is not None + block_n, _ = weight_block_size[0], weight_block_size[1] + shard_offset = ( + (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // + block_n) // tp_size + shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) // + block_n // tp_size) + else: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + + param.load_merged_column_weight(loaded_weight=loaded_weight, + shard_id=loaded_shard_id, + shard_offset=shard_offset, + shard_size=shard_size) + +class QKVParallelLinear(ColumnParallelLinear): + """Linear layers for the attention's QKV transformation. + + Linear layers for the linear transformation of the query, key, and value + vectors in the attention layer. The weight matrix is concatenated along + the output dimension. The layer is parallelized along the head dimension. + When the number of key/value heads is smaller than the number of query + heads (e.g., multi-query/grouped-query attention), the key/value head may + be replicated while the query heads are partitioned. + + Args: + hidden_size: input hidden state size of the transformer. + head_size: size of each attention head. + total_num_heads: total number of attention query heads. + total_num_kv_heads: total number of attention key/value heads. If + None, assume total_num_kv_heads = total_num_heads. + bias: If true, add bias. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + return_bias: If true, return bias together with outputs in forward pass. + """ + + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + ): + self.hidden_size = hidden_size + self.head_size = head_size + self.total_num_heads = total_num_heads + if total_num_kv_heads is None: + total_num_kv_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + # Divide the weight matrix along the last dimension. + tp_size = get_tensor_model_parallel_world_size() + self.num_heads = divide(self.total_num_heads, tp_size) + if tp_size >= self.total_num_kv_heads: + self.num_kv_heads = 1 + self.num_kv_head_replicas = divide(tp_size, + self.total_num_kv_heads) + else: + self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) + self.num_kv_head_replicas = 1 + input_size = self.hidden_size + output_size = (self.num_heads + + 2 * self.num_kv_heads) * tp_size * self.head_size + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj + ] + + super().__init__(input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + return_bias=return_bias) + + def _get_shard_offset_mapping(self, loaded_shard_id: str): + shard_offset_mapping = { + "q": 0, + "k": self.num_heads * self.head_size, + "v": (self.num_heads + self.num_kv_heads) * self.head_size, + "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size + } + return shard_offset_mapping.get(loaded_shard_id) + + def _get_shard_size_mapping(self, loaded_shard_id: str): + shard_size_mapping = { + "q": self.num_heads * self.head_size, + "k": self.num_kv_heads * self.head_size, + "v": self.num_kv_heads * self.head_size, + } + return shard_size_mapping.get(loaded_shard_id) + + def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, + loaded_weight: torch.Tensor): + """ + Handle special case for models where QKV layers are already + fused on disk. In this case, we have no shard id. This function + determmines the shard id by splitting these layers and then calls + the weight loader using the shard id. + + An example of a model with these fused layers: + https://huggingface.co/microsoft/Phi-3-mini-4k-instruct + """ + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.total_num_heads * self.head_size), + ("k", self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size), + ("v", + (self.total_num_heads + self.total_num_kv_heads) * self.head_size, + self.total_num_kv_heads * self.head_size), + ] + + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if isinstance(param, (PackedColumnParameter, PackedvLLMParameter + )) and param.packed_dim == param.output_dim: + shard_size, shard_offset = \ + param.adjust_shard_indexes_for_packing( + shard_size=shard_size, shard_offset=shard_offset) + + loaded_weight_shard = loaded_weight.narrow(param.output_dim, + shard_offset, + shard_size) + self.weight_loader_v2(param, loaded_weight_shard, shard_id) + + def weight_loader_v2(self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + if loaded_shard_id is None: # special case for certain models + if isinstance(param, PerTensorScaleParameter): + param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0) + return + elif type(param) in (RowvLLMParameter, BasevLLMParameter): + param.load_qkv_weight(loaded_weight=loaded_weight) + return + # TODO: @dsikka - move to parameter.py + self._load_fused_module_from_checkpoint(param, loaded_weight) + return + + assert loaded_shard_id in ["q", "k", "v"] + + shard_offset = self._get_shard_offset_mapping(loaded_shard_id) + shard_size = self._get_shard_size_mapping(loaded_shard_id) + + # Note(simon): This is needed for Qwen3's fp8 quantization. + if isinstance(param, BlockQuantScaleParameter): + assert self.quant_method is not None + assert hasattr(self.quant_method, "quant_config") + weight_block_size = self.quant_method.quant_config.weight_block_size + block_n, _ = weight_block_size[0], weight_block_size[1] + shard_offset = (shard_offset + block_n - 1) // block_n + shard_size = (shard_size + block_n - 1) // block_n + + param.load_qkv_weight(loaded_weight=loaded_weight, + num_heads=self.num_kv_head_replicas, + shard_id=loaded_shard_id, + shard_offset=shard_offset, + shard_size=shard_size) + + def weight_loader(self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + + # Special case for GGUF + # initialize GGUF param after we know the quantize type + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + idx_map = {"q": 0, "k": 1, "v": 2} + if loaded_shard_id is not None: + param.data[idx_map[loaded_shard_id]].copy_(loaded_weight) + param.shard_weight_type[loaded_shard_id] = loaded_weight.item() + else: + param.shard_weight_type = { + k: loaded_weight.item() + for k in idx_map + } + return + + if is_gguf_weight: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + output_dim = getattr(param, "output_dim", None) + shard_size = loaded_weight.size(output_dim) // tp_size + start_idx = tp_rank * shard_size + + if loaded_shard_id is not None: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + param.shard_id.append(loaded_shard_id) + param.shard_id_map[loaded_shard_id] = len(param.data_container) + param.data_container.append(loaded_weight) + return + + param_data = param.data + output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + + # Special case for per-tensor scales in fused case. + needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) + is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod) + + if loaded_shard_id is None: + # Loaded weight is already fused on disk (qkv). + # (e.g., Phi-3's qkv_proj). + if output_dim is None: + if needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, 0) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.total_num_heads * self.head_size), + ("k", self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size), + ("v", (self.total_num_heads + self.total_num_kv_heads) * + self.head_size, self.total_num_kv_heads * self.head_size), + ] + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) + + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantized Weights. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + if use_bitsandbytes_4bit: + orig_qkv_offsets = { + "q": (0, self.total_num_heads * self.head_size), + "k": (self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size), + "v": + ((self.total_num_heads + self.total_num_kv_heads) * + self.head_size, + self.total_num_kv_heads * self.head_size), + "total": + ((self.total_num_heads + 2 * self.total_num_kv_heads) * + self.head_size, 0) + } + + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( + param, orig_qkv_offsets, shard_id) + + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + tp_rank = get_tensor_model_parallel_rank() + assert loaded_shard_id in ["q", "k", "v"] + + # If output dim is defined, use the default loading process. + if output_dim is not None: + if loaded_shard_id == "q": + shard_offset = 0 + shard_size = self.num_heads * self.head_size + elif loaded_shard_id == "k": + shard_offset = self.num_heads * self.head_size + shard_size = self.num_kv_heads * self.head_size + elif loaded_shard_id == "v": + shard_offset = (self.num_heads + + self.num_kv_heads) * self.head_size + shard_size = self.num_kv_heads * self.head_size + # Special case for Quantized Weights. + # If quantized, we need to adjust the offset and size to account + # for the packing. + packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) + is_sharded_weight = getattr(param, "is_sharded_weight", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit + + if use_bitsandbytes_4bit: + orig_qkv_offsets = { + "q": (0, self.num_heads * self.head_size), + "k": (self.num_heads * self.head_size, + self.num_kv_heads * self.head_size), + "v": + ((self.num_heads + self.num_kv_heads) * self.head_size, + self.num_kv_heads * self.head_size), + "total": + ((self.num_heads + 2 * self.num_kv_heads) * self.head_size, + 0) + } + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( + param, orig_qkv_offsets, loaded_shard_id) + + if not envs.VLLM_USE_NN or len(param_data.shape)==1 or is_quantization: + param_data = param_data.narrow(output_dim, shard_offset, + shard_size) + else: + param_data = param_data.narrow(int(not(output_dim)), shard_offset, + shard_size) + + if loaded_shard_id == "q": + shard_id = tp_rank + else: + shard_id = tp_rank // self.num_kv_head_replicas + start_idx = shard_id * shard_size + + if not is_sharded_weight: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + + # Special case for for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_index = ["q", "k", "v"].index(loaded_shard_id) + param_data = param_data.narrow(0, shard_index * shard_size, + shard_size) + # Special case for per-tensor scales in fused case. + elif needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, loaded_shard_id) + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "QKVParallelLinear, assume the weight is the same " + "for all partitions.") + + if envs.VLLM_USE_NN and not is_quantization: + loaded_weight = loaded_weight.t() + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class RowParallelLinear(LinearBase): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + skip_bias_add: This was added to enable performance optimization where + bias can be fused with other element-wise operations. + We skip adding bias but instead return it. + params_dtype: Data type for the parameters. + reduce_results: If true, call all-reduce on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y = X_iA_i + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.down_proj) + return_bias: If true, return bias together with outputs in forward pass. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + ): + # Divide the weight matrix along the first dimension. + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.input_size_per_partition = divide(input_size, self.tp_size) + self.output_size_per_partition = output_size + self.output_partition_sizes = [output_size] + + super().__init__(input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias) + + self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results + + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 if self.quant_method.__class__.__name__ + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + if not reduce_results and (bias and not skip_bias_add): + raise ValueError("When not reduce the results, adding bias to the " + "results can lead to incorrect results") + + if bias: + self.bias = Parameter( + torch.empty(self.output_size, dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce + self.tbo_all_reduce = tbo_all_reduce + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + input_dim = getattr(param, "input_dim", None) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + is_sharded_weight = getattr(param, "is_sharded_weight", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit + + # Special case for GGUF + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, UninitializedParameter): + weight_shape = list(loaded_weight.shape) + if input_dim: + weight_shape[input_dim] = weight_shape[input_dim] // tp_size + param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) + + is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod) + + param_data = param.data + if input_dim is not None and not is_sharded_weight: + if not envs.VLLM_USE_NN or is_quantization: + shard_size = param_data.shape[input_dim] + else: + shard_size = param_data.shape[int(not(input_dim))] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(input_dim, start_idx, + shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + if envs.VLLM_USE_NN and not is_quantization: + loaded_weight = loaded_weight.t() + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def weight_loader_v2(self, param: BasevLLMParameter, + loaded_weight: torch.Tensor): + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + assert loaded_weight.numel() == 1 + loaded_weight = loaded_weight.reshape(1) + + param.load_row_parallel_weight(loaded_weight=loaded_weight) + + def forward( + self, input_ + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: + if self.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. + assert self.quant_method is not None + # Only fuse bias add into GEMM for rank 0 (this ensures that + # bias will not get added more than once in TP>1 case) + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + output_parallel = self.quant_method.apply(self, + input_parallel, + bias=bias_) + if self.reduce_results and self.tp_size > 1: + if envs.VLLM_ENABLE_TBO: + output = self.tbo_all_reduce(output_parallel) + else: + output = tensor_model_parallel_all_reduce(output_parallel) + else: + output = output_parallel + + output_bias = self.bias if self.skip_bias_add else None + + if not self.return_bias: + return output + return output, output_bias + + def extra_repr(self) -> str: + s = f"input_features={self.input_size_per_partition}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + s += f", tp_size={self.tp_size}" + s += f", reduce_results={self.reduce_results}" + return s + + +class QKVCrossParallelLinear(LinearBase): + """Linear layers for efficient cross-attention's QKV transformation. + + Args: + hidden_size: input hidden state size of the transformer. + head_size: size of each attention head. + total_num_heads: total number of attention query heads. + total_num_kv_heads: total number of attention key/value heads. If + None, assume total_num_kv_heads = total_num_heads. + bias: If true, add bias. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__(self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + # input_size and output_size are not used, just for alignment + input_size = hidden_size + output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size + super().__init__(input_size=input_size, + output_size=output_size, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix) + + self.quant_config = quant_config + + # Empty placeholders for loading as a single module. + placeholder_size = 0 + assert self.quant_method is not None + self.quant_method.create_weights(self, + placeholder_size, [placeholder_size], + placeholder_size, + placeholder_size, + self.params_dtype, + weight_loader=self.weight_loader) + + # Use a dictionary to avoid submodules parameters auto-registration: + # drop-in replacement for a `QKVParallelLinear` module. + self.proj = dict() + self.proj["q_proj_decoder"] = ColumnParallelLinear( + input_size=hidden_size, + output_size=total_num_heads * head_size, + bias=bias, + quant_config=quant_config, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + prefix=f"{prefix}.q_proj_decoder") + + self.proj["kv_proj_encoder"] = QKVParallelLinear( + hidden_size=hidden_size, + head_size=head_size, + total_num_heads=0, + total_num_kv_heads=total_num_kv_heads, + bias=bias, + quant_config=quant_config, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + prefix=f"{prefix}.kv_proj_encoder") + + # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1. + self.q_size = self.q_proj_decoder.output_size_per_partition + self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size + + if bias: + self.bias = torch.nn.Parameter() + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.bias = None + + def process_weights_after_loading(self): + for layer in self.proj.values(): + if self.quant_method is not None: + self.quant_method.process_weights_after_loading(layer) + + @property + def q_proj_decoder(self) -> ColumnParallelLinear: + layer = self.proj["q_proj_decoder"] + for name, param in self.named_parameters(): + target_param = getattr(layer, name, None) + if target_param is not None: + self.sync_weight_attrs(param, + target_param, + mode="q_proj_decoder") + return layer + + @property + def kv_proj_encoder(self) -> QKVParallelLinear: + layer = self.proj["kv_proj_encoder"] + for name, param in self.named_parameters(): + target_param = getattr(layer, name, None) + if target_param is not None: + self.sync_weight_attrs(param, + target_param, + mode="kv_proj_encoder") + return layer + + def sync_weight_attrs( + self, + src_param: nn.Parameter, + tgt_param: nn.Parameter, + mode: Literal["q_proj_decoder", "kv_proj_encoder"], + ): + missing_attrs_dict = { + k: getattr(src_param, k) + for k in (set(vars(src_param).keys()) - + set(vars(tgt_param).keys())) + } + # TODO(Isotr0py): handle bitsandbytes 8bit + use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit", + False) + if (missing_attrs_dict and use_bitsandbytes_4bit): + q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard( + missing_attrs_dict) + if mode == "q_proj_decoder": + set_weight_attrs(tgt_param, q_proj_attrs) + elif mode == "kv_proj_encoder": + set_weight_attrs(tgt_param, kv_proj_attrs) + else: + set_weight_attrs(tgt_param, missing_attrs_dict) + + def _is_same_param( + self, + src_param: torch.nn.Parameter, + map_param: torch.nn.Parameter, + ) -> bool: + """Check if two parameters are exactly pointing to same things.""" + # ignore weight_loader because it's always different + key_to_ignore = ["weight_loader", "_weight_loader"] + has_same_type_name = type(src_param) is type(map_param) + src_param_attrs = { + k: v + for k, v in src_param.__dict__.items() if k not in key_to_ignore + } + map_param_attrs = { + k: v + for k, v in map_param.__dict__.items() if k not in key_to_ignore + } + has_same_attrs = src_param_attrs == map_param_attrs + return has_same_type_name and has_same_attrs + + def select_proj_params( + self, + layer: nn.Module, + param: nn.Parameter, + ) -> nn.Parameter: + """ + Given the placeholder param, + return the corresponding param in the proj layers. + """ + target_param_list = [ + v for _, v in layer.named_parameters() + if self._is_same_param(param, v) + ] + assert len(target_param_list) == 1 + target_param = target_param_list[0] + return target_param + + def forward( # type: ignore[override] + self, + decoder_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: + q, _ = self.q_proj_decoder(decoder_hidden_states) + if encoder_hidden_states is None: + # Encoder KV already cached. + k = None + v = None + else: + # Prefill phase, encoder KV cached here. + kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states) + # Split kv in half + k, v = kv_enc.split(self.kv_size, dim=-1) + return q, k, v + + def weight_loader(self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + layer = (self.q_proj_decoder + if loaded_shard_id == "q" else self.kv_proj_encoder) + target_param = self.select_proj_params(layer, param) + shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else () + if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED: + layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args) + else: + layer.weight_loader(target_param, loaded_weight, *shard_id_args) + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", q_size={self.q_size}" + s += f", kv_size={self.kv_size}" + s += f", bias={self.bias is not None}" + s += f", tp_size={get_tensor_model_parallel_world_size()}" + s += ", gather_output=False" + return s \ No newline at end of file diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py new file mode 100644 index 0000000..3d01253 --- /dev/null +++ b/vllm/model_executor/layers/logits_processor.py @@ -0,0 +1,197 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""A layer that compute logits from hidden_stats.""" +import inspect +from concurrent.futures import ThreadPoolExecutor +from typing import Optional + +import torch +import torch.nn as nn + +import vllm.envs as envs +from vllm.distributed import (tensor_model_parallel_all_gather, + tensor_model_parallel_gather) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform + +_logits_processor_threadpool: Optional[ThreadPoolExecutor] = None +if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None: + _logits_processor_threadpool = ThreadPoolExecutor( + envs.VLLM_LOGITS_PROCESSOR_THREADS) + + +class LogitsProcessor(nn.Module): + """Process logits and apply logits processors from sampling metadata. + + This layer does the following: + 1. Gather logits from model hidden_states. + 2. Scale logits if needed. + 3. Apply logits processors (if any). + """ + + def __init__(self, + vocab_size: int, + org_vocab_size: Optional[int] = None, + scale: float = 1.0, + logits_as_input: bool = False, + soft_cap: Optional[float] = None) -> None: + """ + Args: + scale: A scaling factor to apply to the logits. + """ + super().__init__() + self.scale = scale + self.vocab_size = vocab_size + # Whether the input is logits (default is hidden states). + self.logits_as_input = logits_as_input + # original vocabulary size (without LoRA). + self.org_vocab_size = org_vocab_size or vocab_size + # Soft cap the logits. Used in Gemma 2. + self.soft_cap = soft_cap + # Whether to use gather or all-gather to gather the logits. + self.use_all_gather = current_platform.use_all_gather() + + def forward( + self, + lm_head: VocabParallelEmbedding, + hidden_states: torch.Tensor, + sampling_metadata: Optional[SamplingMetadata] = None, + embedding_bias: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + if self.logits_as_input: + logits = hidden_states + else: + if sampling_metadata is not None: + hidden_states = _prune_hidden_states(hidden_states, + sampling_metadata) + + # Get the logits for the next tokens. + logits = self._get_logits(hidden_states, lm_head, embedding_bias) + if logits is not None: + if self.soft_cap is not None: + logits = logits / self.soft_cap + logits = torch.tanh(logits) + logits = logits * self.soft_cap + + if self.scale != 1.0: + logits *= self.scale + + # Apply logits processors (if any). + if sampling_metadata is not None and \ + sampling_metadata.seq_groups is not None: + logits = _apply_logits_processors(logits, sampling_metadata) + + return logits + + def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor: + """gather/all-gather the logits tensor across model parallel group.""" + if self.use_all_gather: + # Gather is not supported for some devices such as TPUs. + # Use all-gather instead. + # NOTE(woosuk): Here, the outputs of every device should not be None + # because XLA requires strict SPMD among all devices. Every device + # should execute the same operations after gathering the logits. + logits = tensor_model_parallel_all_gather(logits) + else: + # None may be returned for rank > 0 + logits = tensor_model_parallel_gather(logits) + return logits + + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[torch.Tensor], + ) -> Optional[torch.Tensor]: + # Get the logits for the next tokens. + logits = lm_head.quant_method.apply(lm_head, + hidden_states, + bias=embedding_bias) + + # Gather logits for TP + logits = self._gather_logits(logits) + + # Remove paddings in vocab (if any). + if logits is not None: + logits = logits[..., :self.org_vocab_size] + return logits + + def extra_repr(self) -> str: + s = f"vocab_size={self.vocab_size}" + s += f", org_vocab_size={self.org_vocab_size}" + s += f", scale={self.scale}, logits_as_input={self.logits_as_input}" + return s + + +def _prune_hidden_states( + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + # NOTE(kzawora): The if guard is needed for Gaudi - in some scenarios + # (warmup, profile_run) we might not have selected_token_indices, + # so we skip pruning. + if sampling_metadata.selected_token_indices is not None: + return hidden_states.index_select( + 0, sampling_metadata.selected_token_indices) + else: + return hidden_states + + +def _apply_logits_processors( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + found_logits_processors = False + logits_processed = 0 + logits_row_ids_and_logits_row_futures = [] + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + logits_processors = sampling_params.logits_processors + if logits_processors: + found_logits_processors = True + + for seq_id, logits_row_idx in zip(seq_ids, + seq_group.sample_indices): + logits_row = logits[logits_row_idx] + past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids + prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids + + if _logits_processor_threadpool is not None: + logits_row_ids_and_logits_row_futures.append( + (logits_row_idx, + _logits_processor_threadpool.submit( + _apply_logits_processors_single_seq, logits_row, + logits_processors, past_tokens_ids, + prompt_tokens_ids))) + else: + logits[logits_row_idx] = \ + _apply_logits_processors_single_seq( + logits_row, logits_processors, past_tokens_ids, + prompt_tokens_ids) + + logits_processed += len(seq_group.sample_indices) + len( + seq_group.prompt_logprob_indices) + + for logits_row_idx, future in logits_row_ids_and_logits_row_futures: + logits[logits_row_idx] = future.result() + + if found_logits_processors: + # verifies that no rows in logits were missed unexpectedly + assert logits_processed == logits.shape[0] + return logits + + +def _apply_logits_processors_single_seq(logits_row, logits_processors, + past_tokens_ids, + prompt_tokens_ids) -> torch.Tensor: + for logits_processor in logits_processors: + parameters = inspect.signature(logits_processor).parameters + if len(parameters) == 3: + logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids, + logits_row) + else: + logits_row = logits_processor(past_tokens_ids, logits_row) + return logits_row diff --git a/vllm/model_executor/layers/mamba/__init__.py b/vllm/model_executor/layers/mamba/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py new file mode 100644 index 0000000..88053fa --- /dev/null +++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from dataclasses import dataclass + +import torch + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.placeholder_attn import ( + PlaceholderAttentionMetadata) +from vllm.platforms import current_platform + + +@dataclass +class Mamba2Metadata: + + has_initial_states: torch.Tensor + prep_initial_states: bool + + chunk_size: int + seq_idx: torch.Tensor + chunk_indices: torch.Tensor + chunk_offsets: torch.Tensor + + +def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]: + """Returns the appropriate metadata classes for the current platform.""" + if current_platform.is_rocm(): + from vllm.attention.backends.rocm_flash_attn import ( + ROCmFlashAttentionMetadata) + return (ROCmFlashAttentionMetadata, PlaceholderAttentionMetadata) + elif current_platform.is_cuda(): + from vllm.attention.backends.flash_attn import FlashAttentionMetadata + from vllm.attention.backends.xformers import XFormersMetadata + return (FlashAttentionMetadata, XFormersMetadata, + PlaceholderAttentionMetadata) + raise ValueError( + f"Unsupported platform for Mamba2: {current_platform.device_type}") + + +def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, + chunk_size: int, + total_seqlens: int): + + cu_seqlens = query_start_loc[1:] # remove prepended 0 + + # outputs will have length expansion of chunks that do not divide + # chunk_size + N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size + > 0).sum() + chunk_indices = torch.arange(N, + dtype=torch.int, + device=query_start_loc.device) + chunk_offsets = torch.zeros((N, ), + dtype=torch.int, + device=query_start_loc.device) + + p = 0 # num of insertions + for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]): + + # if does not divide chunk_size, then there is one chunk insertion + p += (s % chunk_size > 0) + + # get the dimensions + # - the + 1 for _e is to shift the boundary by one chunk + # - this shifting is not needed if chunk_size divides e + _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size + > 0) + + # adjust inidces and offsets + chunk_indices[_s:_e] -= p + chunk_offsets[_s] = s % chunk_size + + return chunk_indices, chunk_offsets + + +def prepare_mamba2_metadata( + chunk_size: int, + attn_metadata: AttentionMetadata, +) -> Mamba2Metadata: + + # compute number of prefill and decode requests + # NOTE: in V0 we assume prefills are before decodes + num_prefills = attn_metadata.num_prefills + num_prefill_tokens = attn_metadata.num_prefill_tokens + + seq_idx = None + chunk_indices, chunk_offsets = None, None + # Need flags to indicate if there are initial states + # currently we really only support the FlashAttention backend + has_initial_states = None + prep_initial_states = False + + # Compute seq_idx, chunk_indices and chunk_offsets for prefill only + if num_prefills > 0: + attn_metadata_instances = get_platform_metadata_classes() + if (isinstance(attn_metadata, attn_metadata_instances) + and attn_metadata.context_lens_tensor is not None): + has_initial_states = \ + attn_metadata.context_lens_tensor[:num_prefills] > 0 #[batch,] + # precompute flag to avoid device syncs in mamba2 layer forwards + # prep is only needed for mamba2 ssd prefill processing + prep_initial_states = torch.any(has_initial_states).item() + + query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1] + seq_idx = torch.repeat_interleave(torch.arange( + num_prefills, dtype=torch.int32, device=query_start_loc.device), + query_start_loc.diff(), + output_size=num_prefill_tokens) + seq_idx.unsqueeze_(0) + + # We compute metadata for chunked prefill once at the top level model + # forward and reuse them in mamba layers. If not needed, they will be + # ignored inside mamba kernels. + if prep_initial_states: + chunk_indices, chunk_offsets = \ + _query_start_loc_to_chunk_indices_offsets( + query_start_loc, chunk_size, num_prefill_tokens) + + return Mamba2Metadata(has_initial_states=has_initial_states, + prep_initial_states=prep_initial_states, + chunk_size=chunk_size, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py new file mode 100644 index 0000000..118bd8d --- /dev/null +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -0,0 +1,245 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from torch import nn +from torch.nn.parameter import Parameter + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.forward_context import get_forward_context +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) +from vllm.model_executor.models.mamba_cache import MambaCacheParams +from vllm.model_executor.utils import set_weight_attrs + + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +@CustomOp.register("mamba_mixer") +class MambaMixer(CustomOp): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__(self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + time_step_rank: int, + use_conv_bias: bool, + use_bias: bool, + use_rms_norm: bool, + rms_norm_has_weight: bool = True, + rms_norm_eps: float = 1e-5, + activation="silu", + is_lora_enabled: bool = False): + super().__init__() + self.time_step_rank = time_step_rank + self.ssm_state_size = ssm_state_size + self.use_rms_norm = use_rms_norm + self.activation = activation + self.is_lora_enabled = is_lora_enabled + + self.conv1d = ColumnParallelLinear( + input_size=conv_kernel_size, + output_size=intermediate_size, + bias=use_conv_bias, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear(hidden_size, + [intermediate_size] * 2, + bias=use_bias) + + # selective projection used to make dt, B and C input dependent + self.x_proj = RowParallelLinear( + intermediate_size, + time_step_rank + ssm_state_size * 2, + bias=False, + ) + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. + self.dt_proj = ColumnParallelLinear(time_step_rank, + intermediate_size, + bias=True, + skip_bias_add=True) + + def weight_loader(param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + param.data.copy_( + loaded_weight.data.split(loaded_weight.shape[0] // tp_size, + dim=0)[tp_rank]) + + def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): + weight_loader(param, -torch.exp(loaded_weight.float())) + + tp_size = get_tensor_model_parallel_world_size() + self.A = nn.Parameter( + torch.empty( + intermediate_size // tp_size, + ssm_state_size, + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(intermediate_size // tp_size)) + + set_weight_attrs(self.D, {"weight_loader": weight_loader}) + set_weight_attrs(self.A, {"weight_loader": A_weight_loader}) + + self.out_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=use_bias, + input_is_parallel=True, + ) + + self.dt_layernorm = RMSNorm( + time_step_rank, + eps=rms_norm_eps, + has_weight=rms_norm_has_weight, + ) if use_rms_norm else None + + self.b_layernorm = RMSNorm( + ssm_state_size, + eps=rms_norm_eps, + has_weight=rms_norm_has_weight, + ) if use_rms_norm else None + + self.c_layernorm = RMSNorm( + ssm_state_size, + eps=rms_norm_eps, + has_weight=rms_norm_has_weight, + ) if use_rms_norm else None + + def forward_native(self, hidden_states: torch.Tensor, + conv_state: torch.Tensor, ssm_state: torch.Tensor): + pass + + def forward_cuda(self, hidden_states: torch.Tensor, + mamba_cache_params: MambaCacheParams): + + attn_metadata: AttentionMetadata = get_forward_context().attn_metadata + + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) + hidden_states, gate = projected_states.chunk(2, dim=-2) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + hidden_states = causal_conv1d_fn( + hidden_states, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=mamba_cache_params.conv_state, + has_initial_state=attn_metadata.context_lens_tensor > 0, + cache_indices=mamba_cache_params.state_indices_tensor, + query_start_loc=attn_metadata.query_start_loc) + else: + hidden_states = causal_conv1d_update( + hidden_states.transpose(0, 1), + mamba_cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=mamba_cache_params.state_indices_tensor) + hidden_states = hidden_states.transpose(0, 1) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + + if self.is_lora_enabled: + # lora kernel requires contiguous tensor + ssm_parameters = self.x_proj( + hidden_states.transpose(-2, -1).contiguous())[0] + else: + ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] + + time_step, B, C = torch.split( + ssm_parameters, + [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], + dim=-1, + ) + if self.use_rms_norm: + assert self.dt_layernorm is not None + assert self.b_layernorm is not None + assert self.c_layernorm is not None + time_step = self.dt_layernorm(time_step.contiguous()) + B = self.b_layernorm(B.contiguous()) + C = self.c_layernorm(C.contiguous()) + + discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = (self.dt_proj.bias.float() if hasattr( + self.dt_proj, "bias") else None) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + scan_outputs = selective_scan_fn( + hidden_states, + mamba_cache_params.ssm_state, + discrete_time_step, + self.A, + B.transpose(-2, -1), + C.transpose(-2, -1), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + cache_indices=mamba_cache_params.state_indices_tensor, + has_initial_state=attn_metadata.context_lens_tensor > 0, + query_start_loc=attn_metadata.query_start_loc) + else: + scan_outputs = selective_state_update( + mamba_cache_params.ssm_state, + hidden_states.transpose(0, 1), + discrete_time_step.transpose(0, 1), + self.A, + B, + C, + self.D, + gate.transpose(0, 1), + time_proj_bias, + dt_softplus=True, + state_batch_indices=mamba_cache_params.state_indices_tensor) + scan_outputs = scan_outputs.transpose(0, 1) + + # 4. Final linear projection + if self.is_lora_enabled: + # lora kernel requires contiguous tensor + contextualized_states = self.out_proj( + scan_outputs.transpose(-2, -1).contiguous())[0] + else: + contextualized_states = self.out_proj( + scan_outputs.transpose(-2, -1))[0] + return contextualized_states diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py new file mode 100644 index 0000000..9dcbcb2 --- /dev/null +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -0,0 +1,731 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, Union + +import torch +from torch import nn + +from vllm import envs +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import get_current_vllm_config +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) +from vllm.forward_context import get_forward_context +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_state_update) +from vllm.model_executor.layers.mamba.ops.ssd_combined import ( + mamba_chunk_scan_combined) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import ( + LoaderFunction, composed_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.mamba_cache import MambaCacheParams +from vllm.model_executor.utils import set_weight_attrs +from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata + +# Added by the IBM Team, 2024 + + +# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated +@CustomOp.register("mixer2_gated_rms_norm") +class Mixer2RMSNormGated(CustomOp): + + def __init__(self, + full_hidden_size: int, + full_n_groups: int, + use_rms_norm: bool = True, + eps: float = 1e-6): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.full_hidden_size = full_hidden_size + self.group_size = full_hidden_size // full_n_groups + self.per_rank_hidden_size = full_hidden_size // self.tp_size + self.n_groups = full_hidden_size // self.group_size + + self.variance_epsilon = eps + self.use_rms_norm = use_rms_norm + if self.use_rms_norm: + # Register norm weight only if we're actually applying RMSNorm + self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size)) + set_weight_attrs(self.weight, + {"weight_loader": sharded_weight_loader(0)}) + else: + # Avoid checkpoint mismatch by skipping unused parameter + self.register_parameter("weight", None) + assert (self.full_hidden_size % self.tp_size == 0 + ), "Tensor parallel world size must divide hidden size." + + def forward_native( + self, + x: torch.Tensor, + gate: torch.Tensor, + ): + # Three tensor-parallel cases: + # 1. n_groups is 1 + # In this case we parallelize along the reduction dim. + # Each rank computes a local sum of squares followed by AllReduce + # 2. tp_size divides n_groups + # Each rank only reduces within its local group(s). + # No collective ops necessary. + # 3. The general case can be pretty complicated so we AllGather + # the input and then redundantly compute the RMSNorm. + input_dtype = x.dtype + x = x * nn.functional.silu(gate.to(torch.float32)) + if not self.use_rms_norm: + return x.to(input_dtype) + + if self.n_groups == 1: + if self.tp_size > 1: + # Compute local sum and then reduce to obtain global sum + local_sums = x.pow(2).sum(dim=-1, keepdim=True) + global_sums = tensor_model_parallel_all_reduce(local_sums) + # Calculate the variance + count = self.tp_size * x.shape[-1] + variance = global_sums / count + + else: + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + else: + redundant_tp: bool = self.n_groups % self.tp_size != 0 + if redundant_tp: + # To handle the general case, redundantly apply the variance + x = tensor_model_parallel_all_gather(x, -1) + + *prefix_dims, hidden_dim = x.shape + group_count = hidden_dim // self.group_size + x_grouped = x.view(*prefix_dims, group_count, self.group_size) + variance = x_grouped.pow(2).mean(-1, keepdim=True) + x_grouped = x_grouped * torch.rsqrt(variance + + self.variance_epsilon) + x = x_grouped.view(*prefix_dims, hidden_dim) + + if redundant_tp: + start = self.per_rank_hidden_size * self.tp_rank + end = start + self.per_rank_hidden_size + x = x[..., start:end] + + return self.weight * x.to(input_dtype) + + def forward_cuda( + self, + x: torch.Tensor, + gate: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + input_dtype = x.dtype + if not self.use_rms_norm: + # Keep gate in float32 for numerical stability during silu + return x * nn.functional.silu(gate.to( + torch.float32)).to(input_dtype) + + if self.tp_size > 1 or self.n_groups != 1: + return self.forward_native(x, gate) + + from vllm import _custom_ops as ops + + # cast x and gate to float32 before silu + out = torch.empty_like(x) + y = x * nn.functional.silu(gate.to(torch.float32)) + ops.rms_norm( + out, + y.to(x.dtype), + self.weight.data, + self.variance_epsilon, + ) + return out + + +def extra_groups_for_head_shards(ngroups: int, tp_size: int): + """Compute the increase in group numbers to account for + replication in order to accompany the head shards.""" + + # in the case ngoups % tp_size == 0, this will be zero + if ngroups % tp_size == 0: + return 0 + + # for n_groups == 1, this is exactly tp_size - n_groups + return tp_size - ngroups + + +def mamba_v2_sharded_weight_loader( + shard_spec: list[tuple[int, int, float]], + tp_size: int, + tp_rank: int, +) -> LoaderFunction: + """Create a weight loader for mamba v2. This ensures that the projections + are correctly sharded so that they can be split into x, B, C. It also + ensures that all the groups corresponding to a head shard is placed + together with it. + """ + + def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + + # - track boundary of (sharded) param, and loaded_weight, respectively + boundary, loaded_boundary = 0, 0 + + # - iterate over the shard specs + for full_dim, extra, duplicate_groups in shard_spec: + # - full dim is the model dim (before TP). + # - extra > 0, means there is expected overall increase + # of dimensions. This is so because of replication. + # - ratio is used map the tp_rank to the actual shard + # rank. This is useful when there is replication of + # groups to accompany head shards. + + # - size of the loaded shard + shard_size = full_dim // tp_size + + # - compute the rank into the loaded shard. + # - if there is replication, different TP shards will + # take from the same rank. + # NOTE: currently we only support duplication + # in the case where num_groups == 1 + rank = 0 if duplicate_groups else tp_rank + + # - leftmost boundary index into loaded weight. + loaded_skip = rank * shard_size + loaded_start_idx = loaded_boundary + loaded_skip + + # - take these many dims from the loaded weight. + take = min(shard_size, full_dim - extra - loaded_skip) + + # - always shard on dim 0 + # - the ignore is for a mundane mypy error as it does not + # seem to handle slices well. + # https://github.com/python/mypy/issues/2410 + param.data[ + boundary:(boundary + take), + ... # type: ignore[misc] + ] = loaded_weight[loaded_start_idx:(loaded_start_idx + + take) # type: ignore[misc] + ] # type: ignore[misc] + + # move indexing boundaries + boundary += shard_size + loaded_boundary += full_dim - extra + + return loader + + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +@CustomOp.register("mamba_mixer2") +class MambaMixer2(CustomOp): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__( + self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + use_conv_bias: bool, + use_bias: bool, + n_groups: int = 1, + num_heads: int = 128, + head_dim: int = 64, + rms_norm_eps: float = 1e-5, + activation: str = "silu", + use_rms_norm: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + chunk_size: int = -1, # the chunk size used by v1 + ): + super().__init__() + + # For TP, the sharding plan is as follows: + # - for the conv modules, since + # conv_dim = intermediate_size * 2 * n_groups * ssm_state_size, + # we shard intermediate_size and n_groups + # - since intermediate_size = n_heads * head_dim, sharding on + # intermediate_size is achieved by sharding on n_heads. + # - IF, world_size divides groups, then sharding + # (n_groups / world_size, n_heads / world_size) + # also maintains the invariant n_heads % n_groups == 0 + # - HOWEVER IF, world_size DOES NOT divide groups, then we need + # to allocate extra space in the shard, such that groups + # may be replicated to follow the head shard. + # - NOTE: currently for the world size DOES NOT divide groups + # case, we only support the case when n_groups == 1 + self.tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + assert (num_heads % self.tp_size == 0 + ), "Tensor parallel world size must divide num heads." + + assert (n_groups % self.tp_size) == 0 or n_groups == 1, ( + "If tensor parallel world size does not divide num_heads, " + "then num_groups must equal 1.") + + assert ( + self.tp_size == 1 or quant_config is None + ), "Tensor parallel currently not supported for quantized models." + + self.ssm_state_size = ssm_state_size + self.conv_kernel_size = conv_kernel_size + self.activation = activation + + self.intermediate_size = intermediate_size + self.head_dim = head_dim + self.num_heads = num_heads + + self.n_groups = n_groups + if n_groups % self.tp_size != 0: + # - for TP we shard conv_dim by sharding on n_groups, + # - but if n_groups cannot divide tp_size, we need to + # extend some extra groups + self.n_groups = n_groups + extra_groups_for_head_shards( + n_groups, self.tp_size) + + self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size + self.conv1d = ColumnParallelLinear( + input_size=conv_kernel_size, + output_size=self.conv_dim, + bias=use_conv_bias, + quant_config=None, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size + self.conv_dim + self.num_heads, + bias=use_bias, + quant_config=quant_config, + ) + + # - because in_proj is a concatenation of 3 weights, we + # need to interleave them before sharding + # - use the custom weight loader mamba_v2_sharded_weight_loader + # for conv1d.bias, covn1d.weight and in_proj.weight + # - need to set these settings, to assign the groups to the head shards + group_shard_settings = ( + self.n_groups * self.ssm_state_size, # expected model size + (self.n_groups - n_groups) * + self.ssm_state_size, # extra dims assigned + n_groups == 1, # if there was only one group + ) + intermediate_settings = (intermediate_size, 0, False) + head_settings = (self.num_heads, 0, False) + + # - the weight already has a "weight_loader" attribute + # which set_weight_attrs will raise if we do not + # delete before trying to override it + # - ditto for the otther two weights below + delattr(self.conv1d.bias, "weight_loader") + set_weight_attrs( + self.conv1d.bias, + { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], + self.tp_size, + tp_rank, + ) + }, + ) + + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, + { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], + self.tp_size, + tp_rank, + ) + }, + ) + + if quant_config is None: + # - quant layers do not have a weight loader + delattr(self.in_proj.weight, "weight_loader") + set_weight_attrs( + self.in_proj.weight, + { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, # for gate + intermediate_settings, + group_shard_settings, + group_shard_settings, + head_settings, # for dt + ], + self.tp_size, + tp_rank, + ) + }, + ) + + # - these are TPed by heads to reduce the size of the + # temporal shape + self.A = nn.Parameter( + torch.empty( + divide(num_heads, self.tp_size), + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(num_heads // self.tp_size)) + self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size)) + self.use_rms_norm = use_rms_norm + + set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) + a_weight_loader = composed_weight_loader( + sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) + set_weight_attrs(self.dt_bias, + {"weight_loader": sharded_weight_loader(0)}) + + self.out_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=use_bias, + input_is_parallel=True, + quant_config=quant_config, + ) + + self.norm = Mixer2RMSNormGated(intermediate_size, + n_groups, + self.use_rms_norm, + eps=rms_norm_eps) + + if envs.VLLM_USE_V1: + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + # The outer list is for v0 PP virtual engine. Though this code path + # only runs for v1, we have to do this to unify with the interface + # of Attention + v0 PP. + # The inner tuple is (conv_state, ssm_state) + self.kv_cache = [(torch.tensor([]), torch.tensor([]))] + assert chunk_size != -1, "chunk_size must be set for v1" + + # NOTE: chunk_size may be -1 for models without v1 support + self.chunk_size = chunk_size + self.prefix = prefix + + def forward_native( + self, + hidden_states: torch.Tensor, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ): + pass + + def forward_cuda( + self, + hidden_states: torch.Tensor, + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + mup_vector: Optional[torch.Tensor] = None, + ): + forward_context = get_forward_context() + # mamba2_metadata contains metadata necessary for the mamba2 triton + # kernels to operate in continuous batching and in chunked prefill + # modes; they are computed at top-level model forward since they + # stay the same and reused for all mamba layers in the same iteration + attn_metadata: AttentionMetadata = forward_context.attn_metadata + if envs.VLLM_USE_V1: + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, Mamba2AttentionMetadata) + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0] + ssm_state = self_kv_cache[1] + state_indices_tensor = attn_metadata.state_indices_tensor + has_initial_states_p = attn_metadata.has_initial_states + prep_initial_states = attn_metadata.prep_initial_states + chunk_size = attn_metadata.chunk_size + seq_idx_p = attn_metadata.seq_idx + chunk_indices_p = attn_metadata.chunk_indices + chunk_offsets_p = attn_metadata.chunk_offsets + else: + conv_state = mamba_cache_params.conv_state + ssm_state = mamba_cache_params.ssm_state + state_indices_tensor = mamba_cache_params.state_indices_tensor + has_initial_states_p = mamba2_metadata.has_initial_states + prep_initial_states = mamba2_metadata.prep_initial_states + chunk_size = mamba2_metadata.chunk_size + seq_idx_p = mamba2_metadata.seq_idx + chunk_indices_p = mamba2_metadata.chunk_indices + chunk_offsets_p = mamba2_metadata.chunk_offsets + + groups_time_state_size = self.n_groups * self.ssm_state_size + + # 1. Gated MLP's linear projection + projected_states, _ = self.in_proj(hidden_states) + + if mup_vector is not None: + projected_states = projected_states * mup_vector + + gate, hidden_states_B_C, dt = torch.split( + projected_states, + [ + self.intermediate_size // self.tp_size, + self.conv_dim // self.tp_size, + self.num_heads // self.tp_size, + ], + dim=-1, + ) + + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + # - get hidden_states, B and C after depthwise convolution. + split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split( + hidden_states_B_C, + [ + self.intermediate_size // self.tp_size, + groups_time_state_size // self.tp_size, + groups_time_state_size // self.tp_size, + ], + dim=-1, + ) + + if envs.VLLM_USE_V1 and attn_metadata is None: + # V1 profile run + hidden_states_B_C = (hidden_states_B_C.transpose( + 0, 1).clone().transpose(0, 1)).contiguous() + hidden_states, _B, _C = split_hidden_states_B_C_fn( + hidden_states_B_C) + hidden_states = self.norm(hidden_states, gate) + out, _ = self.out_proj(hidden_states) + return out + + num_prefills = attn_metadata.num_prefills # request count + num_decodes = attn_metadata.num_decode_tokens # token count (=request) + num_prefill_tokens = attn_metadata.num_prefill_tokens # token count + has_prefill = num_prefills > 0 + has_decode = num_decodes > 0 + + # NOTE: V0 put prefill before decode, v1 puts decode before prefill + # Separate prefill and decode by splitting varlen input + # Split along token dimension + if envs.VLLM_USE_V1: + hidden_states_B_C_d, hidden_states_B_C_p = torch.split( + hidden_states_B_C, + [num_decodes, num_prefill_tokens], + dim=0, + ) + dt_d, dt_p = torch.split( + dt, + [num_decodes, num_prefill_tokens], + dim=0, + ) + # Split along batch dimension + state_indices_tensor_d, state_indices_tensor_p = torch.split( + state_indices_tensor, + [num_decodes, num_prefills], + dim=0, + ) + query_start_loc_p = ( + attn_metadata.query_start_loc[-num_prefills - 1:] - + num_decodes if has_prefill else None) + else: + hidden_states_B_C_p, hidden_states_B_C_d = torch.split( + hidden_states_B_C, + [num_prefill_tokens, num_decodes], + dim=0, + ) + dt_p, dt_d = torch.split( + dt, + [num_prefill_tokens, num_decodes], + dim=0, + ) + # Split along batch dimension + state_indices_tensor_p, state_indices_tensor_d = torch.split( + state_indices_tensor, + [num_prefills, num_decodes], + dim=0, + ) + query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + + 1] + if has_prefill else None) + + ssd_output_list = [] + + # Process prefill requests + if has_prefill: + # 2. Convolution sequence transformation + # - "cache_indices" updates the conv_state cache in positions + # pointed to by "state_indices_tensor" + hidden_states_B_C_p = causal_conv1d_fn( + hidden_states_B_C_p.transpose(0, 1), + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=has_initial_states_p, + cache_indices=state_indices_tensor_p, + query_start_loc=query_start_loc_p).transpose( + 0, 1)[:num_prefill_tokens] + + # TODO: Why is this needed? + hidden_states_B_C_p = hidden_states_B_C_p.contiguous() + hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn( + hidden_states_B_C_p) + + # 3. State Space Model sequence transformation + initial_states = None + if (has_initial_states_p is not None and prep_initial_states): + # making a copy of the states + initial_states = torch.where( + has_initial_states_p[:, None, None, None], + ssm_state[state_indices_tensor_p], 0) + + scan_output, varlen_state = mamba_chunk_scan_combined( + hidden_states_p.view(1, num_prefill_tokens, + self.num_heads // self.tp_size, + self.head_dim), + dt_p.unsqueeze(0), + self.A, + B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, + -1), + C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, + -1), + chunk_size=chunk_size, + D=self.D, + z=None, + dt_bias=self.dt_bias, + seq_idx=seq_idx_p, + chunk_indices=chunk_indices_p, + chunk_offsets=chunk_offsets_p, + cu_seqlens=query_start_loc_p, + initial_states=initial_states, + return_varlen_states=True, + return_final_states=False, + dt_softplus=True, + dt_limit=(0.0, float("inf")), + ) + + # update ssm states + # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor + ssm_state[state_indices_tensor_p] = varlen_state + + # - reshape + ssd_output_list.append(scan_output.view(num_prefill_tokens, -1)) + + # Process decode requests + if has_decode: + # 2. Convolution sequence transformation + hidden_states_B_C_d = causal_conv1d_update( + hidden_states_B_C_d, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=state_indices_tensor_d) + + hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn( + hidden_states_B_C_d) + + # 3. State Space Model sequence transformation + n_groups = self.n_groups // self.tp_size + A_d = self.A[:, None, ...][:, :, None].expand( + -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D_d = self.D[:, None, ...].expand(-1, self.head_dim) + B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups) + C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups) + hidden_states_d = hidden_states_d.view( + -1, self.num_heads // self.tp_size, self.head_dim) + + # - the hidden is reshaped into (bs, num_heads, head_dim) + # - mamba_cache_params.ssm_state's slots will be selected + # using state_indices_tensor_d + + hidden_states_d = selective_state_update( + ssm_state, + hidden_states_d, + dt_d, + A_d, + B_d, + C_d, + D_d, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=state_indices_tensor_d, + ) + + if envs.VLLM_USE_V1: + ssd_output_list.insert( + 0, + hidden_states_d.view(-1, (self.num_heads // self.tp_size) * + self.head_dim)) + else: + ssd_output_list.append( + hidden_states_d.view(-1, (self.num_heads // self.tp_size) * + self.head_dim)) + + # Merge prefill and decode outputs before passing to gated MLP + hidden_states = torch.vstack(ssd_output_list) + + # 4. gated MLP + # GatedRMSNorm internally applying SiLU to the gate + # SiLU is applied internally before normalization, unlike standard + # norm usage + hidden_states = self.norm(hidden_states, gate) + + # 5. Final linear projection + out, _ = self.out_proj(hidden_states) + return out + + def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + world_size = get_tensor_model_parallel_world_size() + + conv_state_shape, temporal_state_shape = None, None + + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = (self.n_groups + + extra_groups_for_head_shards(self.n_groups, world_size)) + + # - heads and n_groups are TP-ed + conv_dim = (self.intermediate_size + + 2 * n_groups * self.ssm_state_size) + conv_state_shape = ( + divide(conv_dim, world_size), + self.conv_kernel_size - 1, + ) + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, d_head, d_state) = (128, 64, 128) + temporal_state_shape = ( + divide(self.num_heads, world_size), + self.head_dim, + self.ssm_state_size, + ) + return conv_state_shape, temporal_state_shape diff --git a/vllm/model_executor/layers/mamba/ops/__init__.py b/vllm/model_executor/layers/mamba/ops/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py new file mode 100644 index 0000000..a10c5ab --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2024, Tri Dao. +# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py + +from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.attention.backends.utils import PAD_SLOT_ID + + +def causal_conv1d_fn(x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID): + """ + x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen + sequences are concatenated from left to right for varlen + weight: (dim, width) + bias: (dim,) + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. prepended by 0. + for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + cache_indices: (batch) int32 + indicates the corresponding state index, + like so: conv_state = conv_states[cache_indices[batch_id]] + has_initial_state: (batch) bool + indicates whether should the kernel take the current state as initial + state for the calculations + conv_states: (...,dim,width - 1) itype + updated inplace if provided + activation: either None or "silu" or "swish" + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(-1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + + ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc, + cache_indices, has_initial_state, activation + in ["silu", "swish"], pad_slot_id) + return x + + +def causal_conv1d_update(x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + pad_slot_id: int = PAD_SLOT_ID): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state + starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation_val = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val, + cache_seqlens, conv_state_indices, pad_slot_id) + if unsqueeze: + x = x.squeeze(-1) + return x diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py new file mode 100644 index 0000000..3f67fc3 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -0,0 +1,414 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py + +import torch +from packaging import version + +from vllm import _custom_ops as ops +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.triton_utils import HAS_TRITON, tl, triton + +TRITON3 = HAS_TRITON and (version.parse(triton.__version__) + >= version.parse("3.0.0")) + +if TRITON3: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) + return dt +else: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + return dt + + +@triton.heuristics( + {"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) +@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) +@triton.heuristics({ + "HAS_STATE_BATCH_INDICES": + lambda args: args["state_batch_indices_ptr"] is not None +}) +@triton.heuristics( + {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) +@triton.jit +def _selective_scan_update_kernel( + # Pointers to matrices + state_ptr, + x_ptr, + dt_ptr, + dt_bias_ptr, + A_ptr, + B_ptr, + C_ptr, + D_ptr, + z_ptr, + out_ptr, + state_batch_indices_ptr, + pad_slot_id, + # Matrix dimensions + batch, + nheads, + dim, + dstate, + nheads_ngroups_ratio, + # Strides + stride_state_batch, + stride_state_head, + stride_state_dim, + stride_state_dstate, + stride_x_batch, + stride_x_head, + stride_x_dim, + stride_dt_batch, + stride_dt_head, + stride_dt_dim, + stride_dt_bias_head, + stride_dt_bias_dim, + stride_A_head, + stride_A_dim, + stride_A_dstate, + stride_B_batch, + stride_B_group, + stride_B_dstate, + stride_C_batch, + stride_C_group, + stride_C_dstate, + stride_D_head, + stride_D_dim, + stride_z_batch, + stride_z_head, + stride_z_dim, + stride_out_batch, + stride_out_head, + stride_out_dim, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + TIE_HDIM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + HAS_D: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_STATE_BATCH_INDICES: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + + # If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate + # is taken from the state_batch_indices_ptr Otherwise, the state coordinate + # is the same as the batch id. + if HAS_STATE_BATCH_INDICES: + state_batch_indices_ptr += pid_b + state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64) + state_ptr += (state_batch_idx * stride_state_batch + + pid_h * stride_state_head) + else: + state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head + + x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head + if HAS_DT_BIAS: + dt_bias_ptr += pid_h * stride_dt_bias_head + A_ptr += pid_h * stride_A_head + B_ptr += pid_b * stride_B_batch + (pid_h // + nheads_ngroups_ratio) * stride_B_group + C_ptr += pid_b * stride_C_batch + (pid_h // + nheads_ngroups_ratio) * stride_C_group + if HAS_Z: + z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) + state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + + offs_n[None, :] * stride_state_dstate) + x_ptrs = x_ptr + offs_m * stride_x_dim + dt_ptrs = dt_ptr + offs_m * stride_dt_dim + if HAS_DT_BIAS: + dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim + if HAS_D: + D_ptr += pid_h * stride_D_head + A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + + offs_n[None, :] * stride_A_dstate) + B_ptrs = B_ptr + offs_n * stride_B_dstate + C_ptrs = C_ptr + offs_n * stride_C_dstate + if HAS_D: + D_ptrs = D_ptr + offs_m * stride_D_dim + if HAS_Z: + z_ptrs = z_ptr + offs_m * stride_z_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) + if HAS_STATE_BATCH_INDICES: + mask &= (state_batch_idx != pad_slot_id) + state = tl.load(state_ptrs, mask=mask, other=0.0) + + x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if not TIE_HDIM: + dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load(A_ptrs, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) + dA = tl.exp(A * dt[:, None]) + else: + dt = tl.load(dt_ptr).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptr).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load(A_ptr).to(tl.float32) + dA = tl.exp(A * dt) # scalar, not a matrix + + B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + if HAS_D: + D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_Z: + z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt + state = state * dA + dB * x[:, None] + + mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) + if HAS_STATE_BATCH_INDICES: + mask &= (state_batch_idx != pad_slot_id) + tl.store(state_ptrs, state, mask=mask) + out = tl.sum(state * C[None, :], axis=1) + if HAS_D: + out += x * D + if HAS_Z: + out *= z * tl.sigmoid(z) + tl.store(out_ptrs, out, mask=offs_m < dim) + + +def selective_state_update(state, + x, + dt, + A, + B, + C, + D=None, + z=None, + dt_bias=None, + dt_softplus=False, + state_batch_indices=None, + pad_slot_id=PAD_SLOT_ID): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + Return: + out: (batch, dim) or (batch, nheads, dim) + """ + has_heads = state.dim() > 3 + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + + _, nheads, dim, dstate = state.shape + batch = x.shape[0] + + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, dstate) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + if state_batch_indices is not None: + assert state_batch_indices.shape == (batch, ) + out = torch.empty_like(x) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else + (0, 0, 0)) + # We don't want autotune since it will overwrite the state + # We instead tune by hand. + BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else + ((16, 4) if dstate <= 32 else + ((8, 4) if dstate <= 64 else + ((4, 4) if dstate <= 128 else ((4, 8)))))) + tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride( + -1) == 0 and dt_bias.stride(-1) == 0 + with torch.cuda.device(x.device.index): + _selective_scan_update_kernel[grid]( + state, + x, + dt, + dt_bias, + A, + B, + C, + D, + z, + out, + state_batch_indices, + pad_slot_id, + batch, + nheads, + dim, + dstate, + nheads // ngroups, + state.stride(0), + state.stride(1), + state.stride(2), + state.stride(3), + x.stride(0), + x.stride(1), + x.stride(2), + dt.stride(0), + dt.stride(1), + dt.stride(2), + *(dt_bias.stride(0), + dt_bias.stride(1)) if dt_bias is not None else 0, + A.stride(0), + A.stride(1), + A.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + C.stride(0), + C.stride(1), + C.stride(2), + *(D.stride(0), D.stride(1)) if D is not None else 0, + z_strides[0], + z_strides[1], + z_strides[2], + out.stride(0), + out.stride(1), + out.stride(2), + dt_softplus, + tie_hdim, + BLOCK_SIZE_M, + num_warps=num_warps, + ) + if not has_heads: + out = out.squeeze(1) + return out + + +def selective_scan_fn(u, + ssm_states, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + query_start_loc=None, + cache_indices=None, + has_initial_state=None, + pad_slot_id=PAD_SLOT_ID) -> torch.Tensor: + """ + u: (dim, total_length) for varlen or (batch, dim, seqlen) + applies changes in place. + ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate) + applies changes in place. + delta: (dim, total_length) for varlen or (batch, dim, seqlen) + A: (dim, dstate) + B: (ngroups, dstate, total_length) for varlen or + (batch,ngroups,dstate,seqlen) + C: (ngroups, dstate, total_length) for varlen or + (batch,ngroups,dstate,seqlen) + D: (dim,) + z: (dim, total_length) for varlen or (batch, dim, seqlen) + dt_bias: (dim,) or (dim) + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. prepended with 0. + for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + cache_indices: (batch) int32 + A tensor with each cell is a correspondent + input and output ssm_state index + has_initial_state: (batch) bool + A tensor populated with ones and zeros, + indicate if the ssm_state at the corresponding index should be + used as initial state. Not providing argument assumes + there's no initial state + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padding entries + that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at indices 0 and 3 + returns + output: (dim, total_length) for varlen or (batch, dim, seqlen) + supports inplace replacement + """ + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3 and query_start_loc is None: + B = B.unsqueeze(1) + if B.dim() == 2 and query_start_loc is not None: + B = B.unsqueeze(0) + if C.dim() == 3 and query_start_loc is None: + C = C.unsqueeze(1) + if C.dim() == 2 and query_start_loc is not None: + C = C.unsqueeze(0) + + ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, + query_start_loc, cache_indices, has_initial_state, + ssm_states, pad_slot_id) + + if z is None: + return delta # output written inplace to delta + else: + return z # output written inplace to z diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py new file mode 100644 index 0000000..11ca125 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -0,0 +1,262 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py + +# ruff: noqa: E501,SIM102 + +import math + +import torch + +from vllm.triton_utils import tl, triton + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['chunk_size', 'K', 'IS_CAUSAL'], +) +@triton.jit +def _bmm_chunk_fwd_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + out_ptr, + seq_idx_ptr, + # Matrix dimensions + seqlen, + chunk_size, + K, + ngroups, + stride_a_batch, + stride_a_seqlen, + stride_a_head, + stride_ak, + stride_b_batch, + stride_b_seqlen, + stride_b_head, + stride_bk, + stride_out_batch, + stride_out_chunk, + stride_out_head, + stride_outm, + stride_outn, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + dot_dtype: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_ch = tl.program_id(axis=2).to(tl.int64) + pid_c = pid_ch // ngroups + pid_h = pid_ch - pid_c * ngroups + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + if IS_CAUSAL: + if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: + return + a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + + offs_n[None, :] * stride_b_seqlen) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0).to(dot_dtype) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & + (offs_n[None, :] < chunk_size_limit), + other=0.0).to(dot_dtype) + acc += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if HAS_SEQ_IDX: + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1) + seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, + mask=offs_n < chunk_size_limit, + other=-2) + acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) + out = acc.to(out_ptr.dtype.element_ty) + + out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + + offs_n[None, :] * stride_outn) + tl.store(out_ptrs, + out, + mask=(offs_m[:, None] < chunk_size) & + (offs_n[None, :] < chunk_size)) + + +def _bmm_chunk_fwd(a, + b, + chunk_size, + seq_idx=None, + causal=False, + output_dtype=None): + """ + Argument: + a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + b: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out. + causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are + guaranteed to be correct. + Return: + out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) + """ + # Check constraints. + has_groups = a.dim() == 4 + if not has_groups: + batch, seqlen, k = a.shape + else: + batch, seqlen, ngroups, k = a.shape + assert b.shape == a.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if a.stride(-1) != 1 and a.stride(1) != 1: + a = a.contiguous() + if b.stride(-1) != 1 and b.stride(1) != 1: + b = b.contiguous() + nchunks = math.ceil(seqlen / chunk_size) + # Allocates output. + out_dtype = a.dtype if output_dtype is None else output_dtype + out = torch.empty( + (batch, nchunks, chunk_size, chunk_size) if not has_groups else + (batch, nchunks, ngroups, chunk_size, chunk_size), + device=a.device, + dtype=out_dtype) + dot_dtype = (tl.bfloat16 + if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else + (tl.float16 if a.dtype == torch.float16 + or b.dtype == torch.float16 else tl.float32)) + grid = lambda META: (triton.cdiv( + chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( + chunk_size, META['BLOCK_SIZE_N']), batch, nchunks + if not has_groups else nchunks * ngroups) + with torch.cuda.device(a.device.index): + _bmm_chunk_fwd_kernel[grid]( + a, + b, + out, + seq_idx, + seqlen, + chunk_size, + k, + ngroups if has_groups else 1, + a.stride(0), + a.stride(1), + 0 if not has_groups else a.stride(2), + a.stride(-1), + b.stride(0), + b.stride(1), + 0 if not has_groups else b.stride(2), + b.stride(-1), + out.stride(0), + out.stride(1), + 0 if not has_groups else out.stride(2), + out.stride(-2), + out.stride(-1), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + causal, + dot_dtype, + HAS_SEQ_IDX=seq_idx is not None, + ) + return out diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py new file mode 100644 index 0000000..365e1c5 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -0,0 +1,589 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py + +# ruff: noqa: E501,SIM102 + +import torch +from packaging import version + +from vllm.triton_utils import tl, triton + +TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 64 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], +) +@triton.jit +def _chunk_scan_fwd_kernel( + # Pointers to matrices + cb_ptr, + x_ptr, + z_ptr, + out_ptr, + out_x_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, + C_ptr, + states_ptr, + D_ptr, + initstates_ptr, + chunk_indices_ptr, + chunk_offsets_ptr, + chunk_meta_num, + # Matrix dimensions + chunk_size, + hdim, + dstate, + batch, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_cb_batch, + stride_cb_chunk, + stride_cb_head, + stride_cb_csize_m, + stride_cb_csize_k, + stride_x_batch, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_z_batch, + stride_z_seqlen, + stride_z_head, + stride_z_hdim, + stride_out_batch, + stride_out_seqlen, + stride_out_head, + stride_out_hdim, + stride_dt_batch, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + stride_C_batch, + stride_C_seqlen, + stride_C_head, + stride_C_dstate, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_init_states_batch, + stride_init_states_head, + stride_init_states_hdim, + stride_init_states_dstate, + stride_D_head, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_TRITON_22: tl.constexpr, + HAS_INITSTATES: tl.constexpr, +): + pid_bc = tl.program_id(axis=1).to(tl.int64) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + if not HAS_INITSTATES: + c_idx = pid_c + c_off = 0 + else: + c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0) + c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0) + + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + cb_ptr += pid_b * stride_cb_batch + c_idx * stride_cb_chunk + ( + pid_h // nheads_ngroups_ratio) * stride_cb_head + x_ptr += pid_b * stride_x_batch + c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + C_ptr += pid_b * stride_C_batch + c_idx * chunk_size * stride_C_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_C_head + + # M-block offsets and prev states + # - logic in next block may override these if there is an active offset + offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head + prev_states_hdim = stride_states_hdim + prev_states_dstate = stride_states_dstate + + chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size) + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen + + # - we only need seq_idx_prev to be aligned to chunk boundary + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, + mask=c_idx >= 1, + other=0) + + if HAS_INITSTATES: + # if there are init states, we only need seq_idx_m to point + # what is the current seq_idx + + # get current seq idx + if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit: + seq_idx_m = tl.load( + seq_idx_ptr + + (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, ) + + # - recall that in ssd_state_passing, for the case c_off == 0 + # i.e., the very first sequence, we made states_ptr hold its initial state + # so this edge case is taken care of + if ((c_off == 0) and + (seq_idx_prev != seq_idx_m + ) # if a seq is changed exactly on boundary + or (c_off > 0) # implies a new example (pseudo chunk) + ): + + # - replace prev_states_ptr with init_states + prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head + prev_states_hdim = stride_init_states_hdim # override strides + prev_states_dstate = stride_init_states_dstate + + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, + mask=offs_m < chunk_size, + other=0.0).to(tl.float32) + + # - handle chunk state limit + if HAS_INITSTATES: + + # have to split this if otherwise compilation will have problems + dA_cs_m_boundary = 0.0 + + # get the c_idx for the next (logica) chunk + c_idx_n = tl.load( + chunk_indices_ptr + (pid_c + 1), + mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, + other=-1 # to trigger different chunk + ) + + # - there are things to consider + # A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct + # contribution of past states + # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to + # encroach into the next sequence, where c_off_n is the offset of the next + # (logical) chunk. + # An equivalent check for B is c_idx == c_idx_n, where there is repetition in + # (logical) chunk indices. + + if (c_idx == c_idx_n) or c_off > 0: + + # get the next offset + c_off_n = tl.load(chunk_offsets_ptr + (pid_c + 1), + mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, + other=chunk_size) + + # in this case, adjust down the chunk_size_limit + if c_idx == c_idx_n: + chunk_size_limit = min(c_off_n, chunk_size_limit) + + # get the cs at the offset boundary + # - c_off == 0 is a passthrough + dA_cs_m_boundary = tl.load( + dA_cumsum_ptr + + (pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize, + mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1) + and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)), + other=0.0).to(tl.float32) + + if HAS_SEQ_IDX: + # - handle seq idx when HAS_INITSTATES==False + if not HAS_INITSTATES: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Without the if (pid_c > -1), with Triton 2.1.0, I get + # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed. + # With Triton 2.2.0, this works + if IS_TRITON_22 or c_idx > -1: + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k_dstate = tl.arange( + 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + + offs_k_dstate[None, :] * stride_C_dstate) + + prev_states_ptrs = prev_states_ptr + ( + offs_n[None, :] * prev_states_hdim + + offs_k_dstate[:, None] * prev_states_dstate) + if HAS_SEQ_IDX: + + if not HAS_INITSTATES: + # - this is for continuous batching where there is no init states + scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), + 0.0) + else: + # - if there is initstates, we will rely on prev_states, no zeroing + # required. + scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary) + else: + scale_m = tl.exp(dA_cs_m) + if BLOCK_SIZE_DSTATE <= 128: + C = tl.load(C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k_dstate[None, :] < dstate), + other=0.0) + + prev_states = tl.load(prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc = tl.dot(C, prev_states) * scale_m[:, None] + else: + for k in range(0, dstate, BLOCK_SIZE_K): + C = tl.load(C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k_dstate[None, :] < dstate - k), + other=0.0) + # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate - k) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + prev_states_ptrs += BLOCK_SIZE_K + acc *= scale_m[:, None] + + offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + + offs_k[None, :] * stride_cb_csize_k) + x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + + offs_n[None, :] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + K_MAX = chunk_size_limit if not IS_CAUSAL else min( + (pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + for k in range(0, K_MAX, BLOCK_SIZE_K): + cb = tl.load(cb_ptrs, + mask=(offs_m[:, None] < chunk_size) & + (offs_k[None, :] < chunk_size - k), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size - k, + other=0.0).to(tl.float32) + # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. + # So we don't need masking wrt seq_idx here. + cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, + other=0.0).to(tl.float32) + cb *= dt_k + if IS_CAUSAL: + mask = offs_m[:, None] >= k + offs_k[None, :] + cb = tl.where(mask, cb, 0.0) + cb = cb.to(x_ptr.dtype.element_ty) + x = tl.load(x_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < hdim), + other=0.0) + acc += tl.dot(cb, x) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + if HAS_D: + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, + mask=offs_n < hdim, + other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + + offs_n[None, :] * stride_x_hdim), + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_n[None, :] < hdim), + other=0.0).to(tl.float32) + acc += x_residual * D + + if HAS_Z: + out_x_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + + offs_out_n[None, :]) + tl.store(out_x_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim)) + + z_ptr += pid_b * stride_z_batch + c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head + z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + + stride_z_hdim * offs_out_n[None, :]) + z = tl.load(z_ptrs, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim), + other=0.0).to(tl.float32) + acc *= z * tl.sigmoid(z) + + out_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + + offs_out_n[None, :] * stride_out_hdim) + tl.store(out_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & + (offs_out_n[None, :] < hdim)) + + +def _chunk_scan_fwd( + cb, + x, + dt, + dA_cumsum, + C, + states, + D=None, + z=None, + seq_idx=None, + chunk_indices=None, + chunk_offsets=None, + initial_states=None, +): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads, ) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + + if initial_states is not None: + # with initial states, we need to take care of how + # seq_idx crosses the boundaries + assert batch == 1, "chunk scan only supports initial states with batch 1" + + if initial_states.shape[0] == 1: + # no in this case no point to use initial states + initial_states = None + else: + assert chunk_indices is not None and chunk_offsets is not None, \ + ( + "chunk_indices and chunk_offsets should have been set" + ) + else: + chunk_indices, chunk_offsets = None, None + else: + chunk_indices, chunk_offsets = None, None + + # Allocates output. + out = torch.empty(batch, + seqlen, + nheads, + headdim, + device=x.device, + dtype=x.dtype) + if z is not None: + out_x = torch.empty(batch, + seqlen, + nheads, + headdim, + device=x.device, + dtype=x.dtype) + assert out_x.stride() == out.stride() + else: + out_x = None + + grid = lambda META: ( + triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( + headdim, META['BLOCK_SIZE_N']), batch * nchunks + if chunk_offsets is None else len(chunk_offsets), nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2), + z.stride(3)) if z is not None else (0, 0, 0, 0)) + _chunk_scan_fwd_kernel[grid]( + cb, + x, + z, + out, + out_x, + dt, + dA_cumsum, + seq_idx, + C, + states, + D, + initial_states, + chunk_indices, + chunk_offsets, + len(chunk_indices) if chunk_indices is not None else 0, + chunk_size, + headdim, + dstate, + batch, + seqlen, + nheads // ngroups, + cb.stride(0), + cb.stride(1), + cb.stride(2), + cb.stride(3), + cb.stride(4), + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + z_strides[0], + z_strides[1], + z_strides[2], + z_strides[3], + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + dt.stride(0), + dt.stride(2), + dt.stride(1), + dt.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else + (0, 0)), + C.stride(0), + C.stride(1), + C.stride(2), + C.stride(3), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + states.stride(4), + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3)) if initial_states is not None else + (0, 0, 0, 0)), + D.stride(0) if D is not None else 0, + True, + D is not None, + D.dim() == 2 if D is not None else True, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + HAS_Z=z is not None, + HAS_SEQ_IDX=seq_idx is not None, + IS_TRITON_22=TRITON_22, + HAS_INITSTATES=initial_states is not None, + ) + return out, out_x diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py new file mode 100644 index 0000000..ad58a99 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -0,0 +1,751 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py + +# ruff: noqa: E501 + +import math + +import torch + +from vllm.triton_utils import tl, triton + +from .mamba_ssm import softplus + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_H': 1}), + triton.Config({'BLOCK_SIZE_H': 2}), + triton.Config({'BLOCK_SIZE_H': 4}), + triton.Config({'BLOCK_SIZE_H': 8}), + triton.Config({'BLOCK_SIZE_H': 16}), + triton.Config({'BLOCK_SIZE_H': 32}), + triton.Config({'BLOCK_SIZE_H': 64}), + ], + key=['chunk_size', 'nheads'], +) +@triton.jit +def _chunk_cumsum_fwd_kernel( + # Pointers to matrices + dt_ptr, + A_ptr, + dt_bias_ptr, + dt_out_ptr, + dA_cumsum_ptr, + # Matrix dimension + batch, + seqlen, + nheads, + chunk_size, + dt_min, + dt_max, + # Strides + stride_dt_batch, + stride_dt_seqlen, + stride_dt_head, + stride_A_head, + stride_dt_bias_head, + stride_dt_out_batch, + stride_dt_out_chunk, + stride_dt_out_head, + stride_dt_out_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_CHUNK: tl.constexpr, +): + pid_b = tl.program_id(axis=0) + + # if dt is long, may cause problems, so use 64 bit + # https://github.com/triton-lang/triton/issues/1058 + pid_c = tl.program_id(axis=1).to(tl.int64) + pid_h = tl.program_id(axis=2) + dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen + dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) + dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + + offs_c[None, :] * stride_dt_seqlen) + A_ptrs = A_ptr + offs_h * stride_A_head + dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + + offs_c[None, :] * stride_dt_out_csize) + dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + + offs_c[None, :] * stride_dA_cs_csize) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + dt = tl.load(dt_ptrs, + mask=(offs_h[:, None] < nheads) & + (offs_c[None, :] < chunk_size_limit), + other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, + mask=offs_h < nheads, + other=0.0).to(tl.float32) + dt += dt_bias[:, None] + if DT_SOFTPLUS: + dt = tl.where(dt <= 20.0, softplus(dt), dt) + # As of Triton 2.2.0, tl.clamp is not available yet + # dt = tl.clamp(dt, dt_min, dt_max) + dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) + dt = tl.where( + (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, + 0.0) + tl.store(dt_out_ptrs, + dt, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) + dA = dt * A[:, None] + dA_cs = tl.cumsum(dA, axis=1) + tl.store(dA_cs_ptrs, + dA_cs, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_state_fwd_kernel( + # Pointers to matrices + x_ptr, + b_ptr, + states_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, + # Matrix dimensions + hdim, + dstate, + chunk_size, + batch, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_x_batch, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_b_batch, + stride_b_seqlen, + stride_b_head, + stride_b_dstate, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_dt_batch, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1).to(tl.int64) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_b_head + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + + offs_k[:, None] * stride_b_seqlen) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + if HAS_SEQ_IDX: + seq_idx_last = tl.load(seq_idx_ptr + + (chunk_size_limit - 1) * stride_seq_idx_seqlen) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load(x_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_k[None, :] < chunk_size_limit - k), + other=0.0) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_k = tl.load(seq_idx_ptrs, + mask=offs_k < chunk_size_limit - k, + other=-1) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k + else: + scale = tl.where(seq_idx_k == seq_idx_last, + tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + + offs_n[None, :] * stride_states_dstate) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +@triton.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=2), + ], + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_state_varlen_kernel( + # Pointers to matrices + x_ptr, + b_ptr, + dt_ptr, + dA_cumsum_ptr, + chunk_states_ptr, + cu_seqlens_ptr, + states_ptr, + initstates_ptr, + # Matrix dimensions + hdim, + dstate, + chunk_size, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_b_seqlen, + stride_b_head, + stride_b_dstate, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_chunk_states_chunk, + stride_chunk_states_head, + stride_chunk_states_hdim, + stride_chunk_states_dstate, + stride_states_batch, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_init_states_batch, + stride_init_states_head, + stride_init_states_hdim, + stride_init_states_dstate, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + HAS_INITSTATES: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + end_idx = tl.load(cu_seqlens_ptr + pid_b + 1) + pid_c = (end_idx - 1) // chunk_size + b_ptr += pid_c * chunk_size * stride_b_seqlen + ( + pid_h // nheads_ngroups_ratio) * stride_b_head + x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head + + if HAS_INITSTATES: + # if there are init states provided, we differentiate between states (which + # are boundary conditions at a chunk boundary) and initstates (which are boundary + # conditions when a new example in a cont batch starts) + initstates_ptr += pid_h * stride_init_states_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + + offs_k[None, :] * stride_x_seqlen) + b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + + offs_k[:, None] * stride_b_seqlen) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * + stride_dA_cs_csize).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + + chunk_size_limit = end_idx - pid_c * chunk_size + start_idx = tl.load(cu_seqlens_ptr + pid_b) + start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load(x_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_k[None, :] < chunk_size_limit - k) & + (offs_k[None, :] >= start_idx_cur - k), + other=0.0) + b = tl.load(b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & + (offs_n[None, :] < dstate) & + (offs_k[:, None] >= start_idx_cur - k), + other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, + mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, + other=0.0).to(tl.float32) + scale = tl.where( + (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), + tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk + # If HAS_INITSTATES==True need to consider two possiblties + # - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs + # - if state_idx >= pid * chunk_size, then we need to insert initstates + if ((start_idx < pid_c * chunk_size) # first chunk + or (HAS_INITSTATES)): + + dA_cs_boundary = 0.0 # default + + if not HAS_INITSTATES: + past_states_ptrs = chunk_states_ptr + ( + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate) + else: + + # - this seems repetitive, buts its to help the compiler + if start_idx < pid_c * chunk_size: + past_states_ptrs = chunk_states_ptr + ( + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate) + else: + past_states_ptrs = initstates_ptr + ( + pid_b * stride_init_states_batch + + offs_m[:, None] * stride_init_states_hdim + + offs_n[None, :] * stride_init_states_dstate) + + # need to adjust the boundary + if start_idx > pid_c * chunk_size: + dA_cs_boundary = tl.load(dA_cumsum_ptr + + (start_idx - pid_c * chunk_size - + 1) * stride_dA_cs_csize).to( + tl.float32) + + past_states = tl.load(past_states_ptrs, + mask=(offs_m[:, None] < hdim) & + (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) + + scale = tl.exp(dA_cs_last - dA_cs_boundary) + acc += past_states * scale + + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + + offs_n[None, :] * stride_states_dstate) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +def _chunk_cumsum_fwd(dt, + A, + chunk_size, + dt_bias=None, + dt_softplus=False, + dt_limit=(0.0, float("inf"))): + batch, seqlen, nheads = dt.shape + assert A.shape == (nheads, ) + if dt_bias is not None: + assert dt_bias.shape == (nheads, ) + nchunks = math.ceil(seqlen / chunk_size) + dt_out = torch.empty(batch, + nheads, + nchunks, + chunk_size, + device=dt.device, + dtype=torch.float32) + dA_cumsum = torch.empty(batch, + nheads, + nchunks, + chunk_size, + device=dt.device, + dtype=torch.float32) + grid_chunk_cs = lambda META: (batch, nchunks, + triton.cdiv(nheads, META['BLOCK_SIZE_H'])) + with torch.cuda.device(dt.device.index): + _chunk_cumsum_fwd_kernel[grid_chunk_cs]( + dt, + A, + dt_bias, + dt_out, + dA_cumsum, + batch, + seqlen, + nheads, + chunk_size, + dt_limit[0], + dt_limit[1], + dt.stride(0), + dt.stride(1), + dt.stride(2), + A.stride(0), + dt_bias.stride(0) if dt_bias is not None else 0, + dt_out.stride(0), + dt_out.stride(2), + dt_out.stride(1), + dt_out.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + dt_softplus, + HAS_DT_BIAS=dt_bias is not None, + BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), + ) + return dA_cumsum, dt_out + + +def _chunk_state_fwd(B, + x, + dt, + dA_cumsum, + seq_idx=None, + states=None, + states_in_fp32=True): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if states is not None: + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + else: + states_dtype = torch.float32 if states_in_fp32 else B.dtype + states = torch.empty((batch, nchunks, nheads, headdim, dstate), + device=x.device, + dtype=states_dtype) + grid = lambda META: ( + triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv( + dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_fwd_kernel[grid]( + x, + B, + states, + dt, + dA_cumsum, + seq_idx, + headdim, + dstate, + chunk_size, + batch, + seqlen, + nheads // ngroups, + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + B.stride(0), + B.stride(1), + B.stride(2), + B.stride(-1), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + states.stride(4), + dt.stride(0), + dt.stride(2), + dt.stride(1), + dt.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_SEQ_IDX=seq_idx is not None, + ) + return states + + +def chunk_state_varlen(B, + x, + dt, + dA_cumsum, + cu_seqlens, + chunk_states, + initial_states=None): + total_seqlen, nheads, headdim = x.shape + _, nchunks, chunk_size = dt.shape + _, ngroups, dstate = B.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + assert nheads % ngroups == 0 + assert B.shape == (total_seqlen, ngroups, dstate) + assert dt.shape == (nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert chunk_states.shape == (nchunks, nheads, headdim, dstate) + + if initial_states is not None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + + states = torch.empty(batch, + nheads, + headdim, + dstate, + dtype=chunk_states.dtype, + device=chunk_states.device) + grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton. + cdiv(dstate, META['BLOCK_SIZE_N']), batch, nheads) + with torch.cuda.device(x.device.index): + _chunk_state_varlen_kernel[grid]( + x, + B, + dt, + dA_cumsum, + chunk_states, + cu_seqlens, + states, + initial_states, + headdim, + dstate, + chunk_size, + total_seqlen, + nheads // ngroups, + x.stride(0), + x.stride(1), + x.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + dt.stride(1), + dt.stride(0), + dt.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + chunk_states.stride(0), + chunk_states.stride(1), + chunk_states.stride(2), + chunk_states.stride(3), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3)) if initial_states is not None else + (0, 0, 0, 0)), + HAS_INITSTATES=initial_states is not None) + return states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py new file mode 100644 index 0000000..b121275 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -0,0 +1,232 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py + +# ruff: noqa: E501 + +import torch +from einops import rearrange +from packaging import version + +from vllm.triton_utils import triton + +from .ssd_bmm import _bmm_chunk_fwd +from .ssd_chunk_scan import _chunk_scan_fwd +from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd, + chunk_state_varlen) +from .ssd_state_passing import _state_passing_fwd + +TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + + +def _mamba_chunk_scan_combined_fwd(x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + chunk_indices=None, + chunk_offsets=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf"))): + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, seqlen, nheads) + assert A.shape == (nheads, ) + assert C.shape == B.shape + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads, ) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if x.stride(-1) != 1 and x.stride( + 1) != 1: # Either M or K dimension should be contiguous + x = x.contiguous() + if z is not None and z.stride(-1) != 1 and z.stride( + 1) != 1: # Either M or K dimension should be contiguous + z = z.contiguous() + if D is not None and D.stride(-1) != 1: + D = D.contiguous() + if initial_states is not None: + if cu_seqlens is None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + else: + assert initial_states.shape == (len(cu_seqlens) - 1, nheads, + headdim, dstate) + + # This function executes 5 sub-functions for computing mamba + # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/ + # which has a minimal implementation to understand the below operations + # - as explained by the blog, mamba is a special case of causal attention + # - the idea is to chunk the attention matrix and compute each + # submatrix separately using different optimizations. + # - see the blog and paper for a visualization of the submatrices + # which we refer to in the comments below + + # 1. Compute chunked cumsum of A * dt + # - here dt may go through a softplus activation + dA_cumsum, dt = _chunk_cumsum_fwd(dt, + A, + chunk_size, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + dt_limit=dt_limit) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + states = _chunk_state_fwd(B, + x, + dt, + dA_cumsum, + seq_idx=seq_idx, + states_in_fp32=True) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + # - for handling chunked prefill, this requires i) initial_states + # ii) seq_idx and iii) is_cont_batched to be all specified. + # - When a new seq_idx is detected, we will stop passing the prev_state + # and switch accordingly to the init_state corresponding to the new seq_idx. + # - this will ensure that states will be updated with the rightmost flushed seq_idx + # of the previous chunk. This implies that the first chunk of states is either 0 + # or equal to init_states of the first example. + states, final_states = _state_passing_fwd( + rearrange(states, "... p n -> ... (p n)"), + dA_cumsum[:, :, :, -1], + initial_states=rearrange(initial_states, "... p n -> ... (p n)") + if initial_states is not None else None, + seq_idx=seq_idx, + chunk_size=chunk_size, + out_dtype=C.dtype, + is_cont_batched=cu_seqlens is not None) + states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate) + for t in [states, final_states]) + + # 4. Compute batched matrix multiply for C_j^T B_i terms + CB = _bmm_chunk_fwd(C, + B, + chunk_size, + seq_idx=seq_idx, + output_dtype=torch.float32) + + # 5. Scan and compute the diagonal blocks, taking into + # account past causal states. + # - if initial states are provided, then states information will be + # augmented with initial_states. + # - to do this properly, we need to account for example changes in + # the continuous batch, therefore we introduce pseudo chunks, which is + # a chunk that is split up each time an example changes. + # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had + # a seq_idx change, in which case we take states information from + # init_states. + out, out_x = _chunk_scan_fwd( + CB, + x, + dt, + dA_cumsum, + C, + states, + D=D, + z=z, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + initial_states=initial_states, + ) + if cu_seqlens is None: + return out, out_x, dt, dA_cumsum, states, final_states + else: + assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" + varlen_states = chunk_state_varlen( + B.squeeze(0), + x.squeeze(0), + dt.squeeze(0), + dA_cumsum.squeeze(0), + cu_seqlens, + states.squeeze(0), + initial_states=initial_states, + ) + return out, out_x, dt, dA_cumsum, states, final_states, varlen_states + + +def mamba_chunk_scan_combined(x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + chunk_indices=None, + chunk_offsets=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + return_final_states=False, + return_varlen_states=False): + """ + Argument: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) + A: (nheads) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + chunk_size: int + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + dt_bias: (nheads,) + initial_states: (batch, nheads, headdim, dstate) + seq_idx: (batch, seqlen) + cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True + dt_softplus: Whether to apply softplus to dt + Return: + out: (batch, seqlen, nheads, headdim) + """ + + if not return_varlen_states: + cu_seqlens = None + else: + assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True" + out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd( + x, + dt, + A, + B, + C, + chunk_size, + D=D, + z=z, + dt_bias=dt_bias, + initial_states=initial_states, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + cu_seqlens=cu_seqlens, + dt_softplus=dt_softplus, + dt_limit=dt_limit) + if not return_varlen_states: + return out if not return_final_states else (out, final_states) + else: + varlen_states = rest[0] + return (out, + varlen_states) if not return_final_states else (out, + final_states, + varlen_states) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py new file mode 100644 index 0000000..a28fc9f --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -0,0 +1,206 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py + +# ruff: noqa: E501 + +import torch + +from vllm.triton_utils import tl, triton + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 64}), + triton.Config({'BLOCK_SIZE': 128}), + triton.Config({'BLOCK_SIZE': 256}), + triton.Config({'BLOCK_SIZE': 512}), + triton.Config({'BLOCK_SIZE': 1024}), + triton.Config({'BLOCK_SIZE': 2048}), + ], + key=['dim'], +) +@triton.jit +def _state_passing_fwd_kernel( + # Pointers to matrices + states_ptr, + out_ptr, + final_states_ptr, + dA_cs_ptr, + initstates_ptr, + seq_idx_ptr, + # Matrix dimensions + dim, + nchunks, + seqlen, + chunk_size, + # Strides + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_dim, + stride_out_batch, + stride_out_chunk, + stride_out_head, + stride_out_dim, + stride_final_states_batch, + stride_final_states_head, + stride_final_states_dim, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_initstates_batch, + stride_initstates_head, + stride_initstates_dim, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + HAS_INITSTATES: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + IS_CONT_BATCHED: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head + dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head + if HAS_INITSTATES: + initstates_ptr += pid_h * stride_initstates_head + if not IS_CONT_BATCHED: + initstates_ptr += pid_b * stride_initstates_batch + + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + states_ptrs = states_ptr + offs_m * stride_states_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim + + # - states will be the past state of the sequence that continues on the current check + if not HAS_INITSTATES: + states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + else: + initstates_ptr += offs_m * stride_initstates_dim + initstates_ptrs = initstates_ptr + # - for cont batches, for the first chunk mean it will be the first batch's + # init state + states = tl.load(initstates_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) + + tl.store(out_ptrs, states, mask=offs_m < dim) + out_ptrs += stride_out_chunk + seq_idx = 0 + for c in range(nchunks): + new_states = tl.load(states_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) + dA_cs = tl.load(dA_cs_ptr).to(tl.float32) + scale = tl.exp(dA_cs) + if HAS_SEQ_IDX: + # - the seq to pass forward is the one that is flushed to the right + # boundary. + # - that is given by seq_idx_new below. + seq_idx_new = tl.load(seq_idx_ptr + + (min((c + 1) * chunk_size, seqlen) - 1) * + stride_seq_idx_seqlen) + if HAS_INITSTATES: + if IS_CONT_BATCHED and seq_idx != seq_idx_new: + # this means in the current chunk the rightmost flushed seq + # has changed. + # - so we do not propagate the state from previous chunk + # - but rather we load that sequence's init state + initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch + + # - update state with seq_idx_new's init state + states = tl.load(initstates_ptrs, + mask=offs_m < dim, + other=0.0).to(tl.float32) + else: + scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) + + seq_idx = seq_idx_new + states = scale * states + new_states + if c < nchunks - 1: + tl.store(out_ptrs, states, mask=offs_m < dim) + else: + tl.store(final_states_ptrs, states, mask=offs_m < dim) + states_ptrs += stride_states_chunk + dA_cs_ptr += stride_dA_cs_chunk + out_ptrs += stride_out_chunk + + +def _state_passing_fwd( + states, + dA_chunk_cumsum, + initial_states=None, + seq_idx=None, + chunk_size=None, + out_dtype=None, + is_cont_batched=False, +): + batch, nchunks, nheads, dim = states.shape + assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + if initial_states is not None: + if is_cont_batched: + # - if cu_seqlens is provided, then the initial states + # are used for continuous batching. In which case we + # require seq_idx to be provided + assert seq_idx is not None, "" + else: + # - this is the regular batching case, where initial + # states are used are for each example of the batch. + assert initial_states.shape == (batch, nheads, dim) + + if seq_idx is not None: + assert chunk_size is not None + seqlen = seq_idx.shape[-1] + assert seq_idx.shape == (batch, seqlen) + out_dtype = states.dtype if out_dtype is None else out_dtype + out = torch.empty((batch, nchunks, nheads, dim), + device=states.device, + dtype=out_dtype) + final_states = torch.empty((batch, nheads, dim), + device=states.device, + dtype=torch.float32) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) + with torch.cuda.device(states.device.index): + _state_passing_fwd_kernel[grid]( + states, + out, + final_states, + dA_chunk_cumsum, + initial_states, + seq_idx, + dim, + nchunks, + seqlen if seq_idx is not None else 0, + chunk_size if seq_idx is not None else 0, + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + final_states.stride(0), + final_states.stride(1), + final_states.stride(2), + dA_chunk_cumsum.stride(0), + dA_chunk_cumsum.stride(2), + dA_chunk_cumsum.stride(1), + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2)) if initial_states is not None else + (0, 0, 0)), + *((seq_idx.stride(0), + seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_INITSTATES=initial_states is not None, + HAS_SEQ_IDX=seq_idx is not None, + IS_CONT_BATCHED=is_cont_batched, + ) + return out, final_states diff --git a/vllm/model_executor/layers/ops/rand.py b/vllm/model_executor/layers/ops/rand.py new file mode 100644 index 0000000..2fbd664 --- /dev/null +++ b/vllm/model_executor/layers/ops/rand.py @@ -0,0 +1,164 @@ +from typing import Optional, Union + +import torch +import triton +import triton.language as tl +from vllm.utils import is_hip + + +def seeded_uniform( + *size, + seeds: torch.Tensor, + out: Optional[torch.Tensor] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str]] = None, + pin_memory: Optional[bool] = False, +) -> torch.Tensor: + """Similar to torch.rand, but allows for seeds to be set per row. + + seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d. + If it is 3d, the additional seeds needed will be derived automatically + in a deterministic fashion: + [ + row 0: [columns_with_seed_0], [columns_with_seed0^1], ... + ] + """ + n_dims = len(size) + + if n_dims > 3: + raise ValueError("seeded_uniform only supports up to 3D tensors") + + if out is None: + out = torch.empty(*size, + dtype=dtype, + device=device, + pin_memory=pin_memory) + elif out.shape != size: + raise ValueError("shape of out and size must be the same") + + if n_dims == 3: + n_rows, n_3d, n_cols = out.shape + stride_row = out.stride(0) + stride_3d = out.stride(1) + elif n_dims == 2: + n_rows, n_cols = out.shape + n_3d = 1 + stride_row = out.stride(0) + stride_3d = 1 + else: + n_cols = out.shape[0] + n_rows = 1 + n_3d = 1 + stride_row = 1 + stride_3d = 1 + + if seeds.ndim != 1: + raise ValueError("seeds must be a 1D tensor") + + if seeds.numel() != n_rows: + raise ValueError( + "seeds must have the same number of elements as out has rows") + + # The philox PRNG Triton uses generates 4 random numbers at once. + # Therefore, the most efficient use of it is to divide the + # block size by 4, and then save the generated random numbers to + # each of the 4 slices of the tensor. + full_block_size = triton.next_power_of_2(n_cols) + philox_block_size = max(full_block_size // 4, 1) + n_slices = full_block_size // philox_block_size + num_warps = 4 + # Manual tuning. This seems to give best performance on A100 for + # simple kernels like this. + if philox_block_size >= 8192: + if is_hip(): + num_warps = 16 + else: + num_warps = 32 + elif philox_block_size >= 4096: + if is_hip(): + num_warps = 8 + else: + num_warps = 16 + elif philox_block_size >= 2048: + num_warps = 8 + + _seeded_uniform_triton[(n_rows, n_3d)]( + out, + seeds, + stride_row, + stride_3d, + seeds.stride(0), + n_rows, + n_3d, + n_cols, + n_slices=n_slices, + num_warps=num_warps, + block_size=philox_block_size, + ) + return out + + +@triton.jit +def _seeded_uniform_triton( + out_ptr: torch.Tensor, + seed_ptr: torch.Tensor, + out_row_stride: int, + out_3d_stride: int, + seed_row_stride: int, + n_rows: int, + n_3d: int, + n_cols: int, + n_slices: tl.constexpr, + block_size: tl.constexpr, +): + """ + Generate a random float32 number in [0, 1) for each element in the output + tensor. The random numbers in a row generated using the seed for that row. + + Args: + out_ptr: The output tensor. + seed_ptr: The per-row seeds to use for random number generation. + out_row_stride: The stride between rows of the output tensor. + out_3d_stride: The stride between 3D slices of the output tensor. + seed_row_stride: The stride between rows of the seed tensor. + n_rows: The number of rows in the output tensor. + n_3d: The size of second dimension of the output tensor, + if output tensor is 3D. + n_cols: The number of columns in the output tensor. + n_slices: The number of philox outputs to use. + """ + tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4") + + # Get the row index. + row_idx = tl.program_id(axis=0) + three_d_idx = tl.program_id(axis=1) + + philox_offsets = tl.arange(0, block_size) + # Get the seed for the current element. + seed = tl.load(seed_ptr + row_idx * seed_row_stride) + if three_d_idx > 0: + seed ^= three_d_idx + # Generate random numbers in [0, 1). + out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets) + + output_row_start_ptr = (out_ptr + row_idx * out_row_stride + + three_d_idx * out_3d_stride) + out1_offsets = philox_offsets + tl.store(output_row_start_ptr + out1_offsets, + out1, + mask=out1_offsets < n_cols) + if n_slices > 1: + out2_offsets = tl.arange(block_size, block_size * 2) + tl.store(output_row_start_ptr + out2_offsets, + out2, + mask=out2_offsets < n_cols) + if n_slices > 2: + out3_offsets = tl.arange(block_size * 2, block_size * 3) + tl.store(output_row_start_ptr + out3_offsets, + out3, + mask=out3_offsets < n_cols) + if n_slices > 3: + out4_offsets = tl.arange(block_size * 3, block_size * 4) + tl.store(output_row_start_ptr + out4_offsets, + out4, + mask=out4_offsets < n_cols) diff --git a/vllm/model_executor/layers/ops/sample.py b/vllm/model_executor/layers/ops/sample.py new file mode 100644 index 0000000..2fc5010 --- /dev/null +++ b/vllm/model_executor/layers/ops/sample.py @@ -0,0 +1,401 @@ +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from vllm.model_executor.layers.ops.rand import seeded_uniform +from vllm.triton_utils.sample import get_num_triton_sampler_splits +from vllm.utils import is_hip + +_EPS: tl.constexpr = 1e-6 + + +def _multi_split_sample( + probs: torch.Tensor, + seeds: torch.Tensor, + n_splits: int, + sampled_tokens_size: Tuple[int, int], + sampled_logprobs_size: Tuple[int, int], + sample_indices: torch.Tensor, + logprobs: torch.Tensor, + *, + modify_greedy_probs: bool = False, + save_logprobs: bool = False, +): + """Sample tokens where vocab size is split into multiple parts + (too large for Triton otherwise).""" + assert seeds.ndim == 2 and seeds.shape[0] == n_splits + split_probs = probs.tensor_split(n_splits, 1) + split_logprobs = logprobs.tensor_split(n_splits, 1) + sampled_tokens_tmp = [ + torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device) + for _ in range(n_splits) + ] + sampled_logprobs_tmp = [ + torch.empty(sampled_logprobs_size, + dtype=probs.dtype, + device=probs.device) for _ in range(n_splits) + ] + # We are purposefuly using sampled_tokens_size as we need to always + # save modified probs in this case. + sampled_modified_probs_tmp = [ + torch.empty(sampled_tokens_size, + dtype=probs.dtype, + device=probs.device) for _ in range(n_splits) + ] + for i in range(n_splits): + n_samples = sample_indices.shape[0] + n_cols = split_probs[i].shape[1] + n_best = sampled_tokens_tmp[i].shape[1] + uniform_noise = seeded_uniform(n_samples, + n_best, + n_cols, + seeds=seeds[i].flatten(), + device=split_probs[i].device, + dtype=split_probs[i].dtype) + # TODO(yard1): See if we can remove the contiguous() calls. + # Will need kernel support. + _sample( + split_probs[i].contiguous(), + split_logprobs[i].contiguous(), + sample_indices, + sampled_tokens_tmp[i], + sampled_logprobs_tmp[i], + sampled_modified_probs_tmp[i], + seeds[i], + uniform_noise, + modify_greedy_probs=False, + save_logprobs=save_logprobs, + save_modified_probs=True, + ) + if i > 0: + # Add offset to sampled tokens + sampled_tokens_tmp[i].add_(i * split_probs[i - 1].shape[1]) + sampled_tokens = torch.stack(sampled_tokens_tmp) + sampled_modified_probs = torch.stack(sampled_modified_probs_tmp) + # Reduce the results from the splits. + sampled_modified_probs, indices = torch.max(sampled_modified_probs, + dim=0, + keepdim=True) + sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0) + if save_logprobs: + sampled_logprobs = torch.stack(sampled_logprobs_tmp) + sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0) + else: + sampled_logprobs = None + sampled_modified_probs = sampled_modified_probs.squeeze(0) + + if modify_greedy_probs: + # We need to modify the greedy probs for the sampled tokens. + # We can't do this in the kernel as we need to know the + # sampled tokens. + probs.fill_(0.0) + probs.scatter_(1, sampled_tokens, 1.0) + + return (sampled_tokens, sampled_logprobs, sampled_modified_probs) + + +def sample( + probs: torch.Tensor, + seeds: torch.Tensor, + *, + max_best_of: int = 1, + sample_indices: Optional[torch.Tensor] = None, + logprobs: Optional[torch.Tensor] = None, + modify_greedy_probs: bool = False, + save_logprobs: bool = False, + _save_modified_probs: bool = False, # pylint: disable=invalid-name +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """Sample tokens from probs. with per-sequence seeds. + + Can sample from a subset of sequences through sample_indices. + + Args: + probs: Probabilities to sample from. + shape = [batch_size, vocab_size] + seeds: Per-sequence seed values. + shape = [n, math.ceil(vocab_size / MAX_TRITON_N_COLS)] + max_best_of: Number of samples to generate per sequence. + Sequence seed will be incremented by 1 each time. + sample_indices: Indices of sequences to sample from. + If not provided, will sample from all sequences. + shape = [n] + logprobs: Log-probabilities of the sampled tokens. + Only used for saving the logprobs if save_logprobs is True. + shape = [batch_size, vocab_size] + modify_greedy_probs: Whether to modify the greedy probabilities + for speculative sampling (sampled token = 1.0, + everything else = 0.0). + save_logprobs: Whether to save the log-probabilities of the + sampled tokens to a tensor. + _save_modified_probs: Whether to save the modified probabilities + (including gumbel noise) of the sampled tokens to a tensor. + DOES NOT include the modification done by modify_greedy_probs + (because we want to use the unmodified probs to pick the best + split in case of multi-split sampling). + This is exposed only for testing. + + Returns: + sampled_tokens: shape = [n, max_best_of] + sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None + sampled_modified_probs: shape = [n, max_best_of] + if save_modified_probs else None + """ + if sample_indices is None: + sample_indices = torch.arange(0, probs.shape[0], device=probs.device) + + sampled_tokens_size = (sample_indices.size(0), max_best_of) + if save_logprobs: + if logprobs is None: + raise ValueError( + "logprobs tensor must be provided if save_logprobs is True") + sampled_logprobs_size = sampled_tokens_size + else: + # Empty tensors to invoke the kernel + sampled_logprobs_size = (0, 0) + logprobs = probs + + assert logprobs is not None + if _save_modified_probs: + sampled_modified_probs_size = sampled_tokens_size + else: + # Empty tensors to invoke the kernel + sampled_modified_probs_size = (0, 0) + + # If the number of columns in probs is too large for Triton to handle, + # we split the tensor and sample from each split separately, and then + # do an argmax+gather to combine the results. + n_splits = get_num_triton_sampler_splits(probs.shape[1]) + if n_splits > 1: + (sampled_tokens, sampled_logprobs, + sampled_modified_probs) = _multi_split_sample( + probs, + seeds, + n_splits, + sampled_tokens_size, + sampled_logprobs_size, + sample_indices, + logprobs=logprobs, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs) + else: + sampled_tokens = torch.empty(sampled_tokens_size, + dtype=torch.long, + device=probs.device) + sampled_logprobs = torch.empty(sampled_logprobs_size, + dtype=probs.dtype, + device=probs.device) + sampled_modified_probs = torch.empty(sampled_modified_probs_size, + dtype=probs.dtype, + device=probs.device) + n_samples = sample_indices.shape[0] + n_cols = probs.shape[1] + uniform_noise = seeded_uniform(n_samples, + max_best_of, + n_cols, + seeds=seeds.flatten(), + device=probs.device, + dtype=probs.dtype) + + _sample( + probs, + logprobs, + sample_indices, + sampled_tokens, + sampled_logprobs, + sampled_modified_probs, + seeds, + uniform_noise, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs, + save_modified_probs=_save_modified_probs, + ) + return (sampled_tokens, sampled_logprobs if save_logprobs else None, + sampled_modified_probs if _save_modified_probs else None) + + +def _sample(probs: torch.Tensor, + logprobs: torch.Tensor, + sample_indices: torch.Tensor, + output_samples: torch.Tensor, + output_logprobs: torch.Tensor, + output_modified_probs: torch.Tensor, + seeds: torch.Tensor, + uniform_noise: torch.Tensor, + *, + modify_greedy_probs: bool = False, + save_logprobs: bool = True, + save_modified_probs: bool = False) -> torch.Tensor: + """Sample tokens from probs. + + Args: + probs [batch_size, vocab_size]: probs to sample from. + logprobs [batch_size, vocab_size]: logprobs (used when + save_logprobsis True). + sample_indices [n]: Indices of the samples to use for each row of probs. + output_samples [n, n_best]: Output tensor to store samples in. + output_logprobs [n, n_best]: Output tensor to store logprobs in. + output_modified_probs [n, n_best]: Output tensor to store + probs of chosen tokens in (modified with noise). + seeds [n]: Seeds to use for sampling. If the seed is 0, we use + greedy sampling. Note this is ONLY used for determining + whether to use random sampling or not. The actual random + noise should be passed as uniform_noise. + uniform_noise [batch_size, n_best, vocab_size]: Uniform + noise to use for random sampling (will be converted + to exponential gumbel noise by the kernel). + modify_greedy_probs: If True, we modify the probs tensor in-place + to encode the sampling method used for each row. This is used + in speculative decoding. Only applies in greedy decoding. + save_logprobs: If True, we save the logprobs of the sampled tokens + in the output_logprobs tensor. + save_modified_probs: If True, we save the modified probs (with noise) + of the sampled tokens in the output_modified_probs tensor. + DOES NOT include the modification done by modify_greedy_probs + (because we want to use the unmodified probs to pick the best + split in case of multi-split sampling). + """ + n_samples = sample_indices.shape[0] + n_cols = probs.shape[1] + n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1 + + # The block size is the smallest power of two greater than the number of + # columns in probs + block_size = triton.next_power_of_2(n_cols) + num_warps = 4 + # Manual tuning. This seems to give best performance on A100 for + # simple kernels like this. + if block_size >= 8192: + if is_hip(): + num_warps = 16 + else: + num_warps = 32 + elif block_size >= 4096: + if is_hip(): + num_warps = 8 + else: + num_warps = 16 + elif block_size >= 2048: + num_warps = 8 + + # Enqueue kernel. The 1D launch grid is simple: we have one kernel + # instance per row of the probs matrix + _sample_triton[(n_samples, n_best)]( + sample_indices, + output_samples, + output_logprobs, + output_modified_probs, + probs, + logprobs, + seeds, + uniform_noise, + output_samples.stride(0), + probs.stride(0), + uniform_noise.stride(0), + uniform_noise.stride(1) if n_best > 1 else 1, + n_samples, + n_cols, + n_best, + num_warps=num_warps, + block_size=block_size, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs, + save_modified_probs=save_modified_probs, + ) + return output_samples, output_logprobs, output_modified_probs + + +@triton.jit +def _uniform_to_exponential(uniform_noise): + """Convert uniform samples to exponential samples.""" + # tl.rand returns values in [0, 1), so we clamp lower bound + # to _EPS to avoid log(0) and thus division by 0 later + lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype) + uniform_noise = tl.maximum(uniform_noise, lb) + # Use the inversion method to turn uniform samples + # into exponential samples + exponential_noise = -tl.log(uniform_noise) + return exponential_noise + + +@triton.jit +def _sample_triton( + sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor, + output_logprobs_ptr: torch.Tensor, + output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor, + logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor, + uniform_noise_ptr: torch.Tensor, output_row_stride: int, + probs_row_stride: int, uniform_noise_row_stride: int, + uniform_noise_best_stride: int, n_samples: int, n_cols: int, + n_best: int, block_size: tl.constexpr, + modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr, + save_modified_probs: tl.constexpr): + # The rows are independent, so we parallelize across those + sample_idx = tl.program_id(0) + best_idx = tl.program_id(1) + + # Load the row index from DRAM + row_idx = tl.load(sample_indices_ptr + sample_idx) + seed = tl.load(seeds_ptr + sample_idx) + uses_random_sampling = seed != 0 + + # The stride represents how much we need to increase the + # pointer to advance 1 row + row_start_ptr = probs_ptr + row_idx * probs_row_stride + + # The block size is the next power of two greater than n_cols, + # so we can fit each row in a single block + col_offsets = tl.arange(0, block_size) + + # Load the row into SRAM, using a mask since block_size may be > than n_cols + row = tl.load(row_start_ptr + col_offsets, + mask=col_offsets < n_cols, + other=float("-inf")) + + if uses_random_sampling: + uniform_noise_start_ptr = (uniform_noise_ptr + + sample_idx * uniform_noise_row_stride + + best_idx * uniform_noise_best_stride) + uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets, + mask=col_offsets < n_cols, + other=0.5) + exponential_noise = _uniform_to_exponential(uniform_noise) + row /= exponential_noise + + sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True) + # clamp sampled token to n_cols - 1 + # this should not be necessary, but we do it + # just in case + if sampled_token >= n_cols: + sampled_token = n_cols - 1 + # Write back output to DRAM + output_row_start_ptr = (output_ptr + sample_idx * output_row_stride + + best_idx) + tl.store(output_row_start_ptr, sampled_token) + + if modify_greedy_probs: # noqa + if not uses_random_sampling: + # Set the probability of the sampled token to 1, all other + # tokens to zero. This is used in speculative decoding where + # the sampling method must be encoded within the sampled + # probability distributions. + row = tl.where(col_offsets == sampled_token, 1.0, 0.0) + tl.store(row_start_ptr + col_offsets, + row, + mask=col_offsets < n_cols) + + if save_modified_probs: + output_row_start_ptr = (output_modified_probs_ptr + + sample_idx * output_row_stride + best_idx) + tl.store(output_row_start_ptr, sampled_value) + + if save_logprobs: + # Load the row into SRAM, using a mask since block_size + # may be > than n_cols + sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride + + sampled_token) + # Write back output to DRAM + output_row_start_ptr = (output_logprobs_ptr + + sample_idx * output_row_stride + best_idx) + tl.store(output_row_start_ptr, sampled_logprob) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py new file mode 100644 index 0000000..d864a91 --- /dev/null +++ b/vllm/model_executor/layers/pooler.py @@ -0,0 +1,473 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from enum import IntEnum +from typing import Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing_extensions import assert_never + +from vllm.config import ModelConfig, PoolerConfig +from vllm.model_executor.pooling_metadata import ( # noqa: E501 + PoolingMetadata as V0PoolingMetadata) +from vllm.model_executor.pooling_metadata import PoolingTensors +from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput +from vllm.transformers_utils.config import ( + get_classification_activation_function, + get_cross_encoder_activation_function) +from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata + +PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] + + +class PoolingType(IntEnum): + """Enumeration for different types of pooling methods.""" + LAST = 0 + ALL = 1 + CLS = 2 + STEP = 3 + MEAN = 4 + + +class SimplePooler(nn.Module): + """A layer that pools specific information from hidden states. + + This layer does the following: + 1. Extracts specific tokens or aggregates data based on pooling method. + 2. Normalizes output if specified. + 3. Returns structured results as `PoolerOutput`. + + Attributes: + pooling_type: The type of pooling to use. + normalize: Whether to normalize the pooled data. + """ + + @staticmethod + def from_pooling_type( + pooling_type: PoolingType, + *, + normalize: bool, + softmax: bool, + step_tag_id: Optional[int] = None, + returned_token_ids: Optional[list[int]] = None, + ) -> "SimplePooler": + if pooling_type == PoolingType.LAST: + assert step_tag_id is None and returned_token_ids is None + return LastPool(normalize=normalize, softmax=softmax) + if pooling_type == PoolingType.ALL: + assert step_tag_id is None and returned_token_ids is None + return AllPool(normalize=normalize, softmax=softmax) + if pooling_type == PoolingType.CLS: + assert step_tag_id is None and returned_token_ids is None + return CLSPool(normalize=normalize, softmax=softmax) + if pooling_type == PoolingType.MEAN: + assert step_tag_id is None and returned_token_ids is None + return MeanPool(normalize=normalize, softmax=softmax) + if pooling_type == PoolingType.STEP: + return StepPool(normalize=normalize, + softmax=softmax, + step_tag_id=step_tag_id, + returned_token_ids=returned_token_ids) + + assert_never(pooling_type) + + def __init__(self, *, normalize: bool, softmax: bool) -> None: + super().__init__() + + self.head = PoolerHead(normalize=normalize, softmax=softmax) + + def get_prompt_lens( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> torch.Tensor: + if isinstance(pooling_metadata, V1PoolingMetadata): + return pooling_metadata.prompt_lens + assert isinstance(hidden_states, torch.Tensor) + return PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + def extract_states( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + raise NotImplementedError + + def build_output(self, data: torch.Tensor) -> PoolingSequenceGroupOutput: + return PoolingSequenceGroupOutput(data) + + def forward( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + pooled_data = self.extract_states(hidden_states, pooling_metadata) + pooled_data = self.head(pooled_data, pooling_metadata) + pooled_outputs = [self.build_output(data) for data in pooled_data] + return PoolerOutput(outputs=pooled_outputs) + + +class CLSPool(SimplePooler): + + def extract_states( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + + if isinstance(hidden_states, list): + result = [] + for req_state, prompt_len in zip(hidden_states, prompt_lens): + assert prompt_len == req_state.shape[0], \ + "partial prefill not supported with CLS pooling" + result.append(req_state[0]) + return result + + first_token_flat_indices = torch.zeros_like(prompt_lens) + first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1] + return hidden_states[first_token_flat_indices] + + +class LastPool(SimplePooler): + + def extract_states( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + if isinstance(hidden_states, list): + return [h[-1] for h in hidden_states] + + prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + + last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1 + return hidden_states[last_token_flat_indices] + + +class AllPool(SimplePooler): + + def extract_states( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + + if isinstance(hidden_states, list): + for req_state, prompt_len in zip(hidden_states, prompt_lens): + assert prompt_len == req_state.shape[0], \ + "partial prefill not supported with ALL pooling" + return hidden_states + + offset = 0 + pooled_data = list[torch.Tensor]() + for prompt_len in prompt_lens: + pooled_data.append(hidden_states[offset:offset + prompt_len]) + offset += prompt_len + + return pooled_data + + +class MeanPool(SimplePooler): + + def extract_states( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + + if isinstance(hidden_states, list): + result = [] + for req_state, prompt_len in zip(hidden_states, prompt_lens): + assert prompt_len == req_state.shape[0], \ + "partial prefill not supported with mean pooling" + result.append(torch.mean(req_state, dim=0, + dtype=torch.float32)) + return result + + # Use float32 for torch.cumsum in MeanPool, + # otherwise precision will be lost significantly. + cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32) + + start_indices = torch.cat([ + torch.tensor([0], device=hidden_states.device), + torch.cumsum(prompt_lens[:-1], dim=0) + ]) + end_indices = torch.cumsum(prompt_lens, dim=0) + return (cumsum[end_indices - 1] - cumsum[start_indices] + + hidden_states[start_indices]) / prompt_lens.unsqueeze(1) + + +class StepPool(SimplePooler): + + def __init__( + self, + *, + normalize: bool, + softmax: bool, + step_tag_id: Optional[int] = None, + returned_token_ids: Optional[list[int]] = None, + ): + super().__init__(normalize=normalize, softmax=softmax) + + self.step_tag_id = step_tag_id + self.returned_token_ids = returned_token_ids + + def get_prompt_token_ids( + self, + pooling_metadata: PoolingMetadata, + ) -> list[torch.Tensor]: + if isinstance(pooling_metadata, V1PoolingMetadata): + return [ + pooling_metadata.prompt_token_ids[i, :num] + for i, num in enumerate(pooling_metadata.prompt_lens) + ] + return [ + torch.tensor(seq_data_i.prompt_token_ids) + for seq_data_i in pooling_metadata.seq_data.values() + ] + + def extract_states( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + prompt_token_ids = self.get_prompt_token_ids(pooling_metadata) + + pooled_data_lst = list[torch.Tensor]() + if isinstance(hidden_states, list): + for req_state, prompt_len in zip(hidden_states, prompt_lens): + assert prompt_len == req_state.shape[0], \ + "partial prefill not supported with step pooling" + pooled_data_lst = hidden_states + else: + offset = 0 + for prompt_len in prompt_lens: + pooled_data_i = hidden_states[offset:offset + prompt_len] + offset += prompt_len + pooled_data_lst.append(pooled_data_i) + + pooled_data = list[torch.Tensor]() + returned_token_ids = self.returned_token_ids + step_tag_id = self.step_tag_id + + for data, token_id in zip(pooled_data_lst, prompt_token_ids): + if returned_token_ids is not None and len(returned_token_ids) > 0: + data = data[:, returned_token_ids] + + if step_tag_id is not None: + data = data[token_id == step_tag_id] + pooled_data.append(data) + + return pooled_data + + +class PoolerHead(nn.Module): + + def __init__(self, *, normalize: bool, softmax: bool) -> None: + super().__init__() + + self.normalize = normalize + self.softmax = softmax + + def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], + pooling_metadata: PoolingMetadata): + + # Using float32 in PoolerHead + if isinstance(pooled_data, list): + for i in range(len(pooled_data)): + pooled_data[i] = pooled_data[i].to(torch.float32) + else: + pooled_data = pooled_data.to(torch.float32) + + # for matryoshka representation + if isinstance(pooling_metadata, V0PoolingMetadata): + dimensions_list = [ + pooling_param.dimensions + for _, pooling_param in pooling_metadata.seq_groups + ] + else: + assert isinstance(pooled_data, list) + dimensions_list = [ + pooling_param.dimensions + for pooling_param in pooling_metadata.pooling_params + ] + if any(d is not None for d in dimensions_list): + # change the output dimension + assert len(pooled_data) == len(dimensions_list) + if len(set(dimensions_list)) == 1 and not isinstance( + pooled_data, list): + # if all dimensions are the same + d = dimensions_list[0] + pooled_data = pooled_data[..., :d] + else: + pooled_data = [ + vecs if d is None else vecs[..., :d] + for vecs, d in zip(pooled_data, dimensions_list) + ] + + if self.normalize: + if isinstance(pooled_data, list): + pooled_data = [ + F.normalize(data, p=2, dim=-1) for data in pooled_data + ] + else: + pooled_data = F.normalize(pooled_data, p=2, dim=-1) + + if self.softmax: + if isinstance(pooled_data, list): + pooled_data = [ + F.softmax(data, dim=-1) + if data.shape[-1] >= 2 else F.sigmoid(data) + for data in pooled_data + ] + else: + if pooled_data.shape[-1] >= 2: + pooled_data = F.softmax(pooled_data, dim=-1) + else: + pooled_data = F.sigmoid(pooled_data) + + # shape: + # classify (& score) -> (batch_size, num_classes) + # embed -> (batch_size, embedding_dim) or list(embedding_dim) + # (batch_size, dimensions) or list(dimensions) if using MRL + return pooled_data + + +class Pooler(nn.Module): + + @classmethod + def from_config_with_defaults( + cls, + pooler_config: PoolerConfig, + pooling_type: PoolingType, + normalize: bool, + softmax: bool, + step_tag_id: Optional[int] = None, + returned_token_ids: Optional[list[int]] = None, + ) -> SimplePooler: + return SimplePooler.from_pooling_type( + pooling_type=PoolingType[pooler_config.pooling_type] + if pooler_config.pooling_type is not None else pooling_type, + normalize=pooler_config.normalize + if pooler_config.normalize is not None else normalize, + softmax=pooler_config.softmax + if pooler_config.softmax is not None else softmax, + step_tag_id=pooler_config.step_tag_id + if pooler_config.step_tag_id is not None else step_tag_id, + returned_token_ids=pooler_config.returned_token_ids + if pooler_config.returned_token_ids is not None else + returned_token_ids, + ) + + +class ClassifierPooler(nn.Module): + """A pooling layer for classification tasks. + + This layer does the following: + 1. Applies a classification layer to the hidden states. + 2. Optionally applies a pooler layer. + 3. Applies an activation function to the output. In the case of + classification models it is either sigmoid or softmax. In the + case of scoring models, the same behavior is configuration + dependent, as in the sentence-transformers library. + """ + + def __init__( + self, + config: ModelConfig, + classifier: nn.Module, + pooler: Optional[nn.Module] = None, + ): + super().__init__() + self.classifier = classifier + self.pooler = pooler + + self.classification_act_fn = get_classification_activation_function( + config.hf_config) + self.cross_encoder_act_fn = get_cross_encoder_activation_function( + config.hf_config) + + def _get_act_fn(self, use_cross_encoder: bool): + return (self.cross_encoder_act_fn + if use_cross_encoder else self.classification_act_fn) + + def get_prompt_lens( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> torch.Tensor: + if isinstance(pooling_metadata, V1PoolingMetadata): + return pooling_metadata.prompt_lens + assert isinstance(hidden_states, torch.Tensor) + return PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + def forward( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + """Pools sentence pair scores from the hidden_states.""" + prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + + pooled_data = list[torch.Tensor]() + if isinstance(hidden_states, list): + for req_state, prompt_len in zip(hidden_states, prompt_lens): + assert prompt_len == req_state.shape[0], \ + "partial prefill not supported with classifier" + pooled_data = hidden_states + else: + offset = 0 + for prompt_len in prompt_lens: + pooled_data_i = hidden_states[offset:offset + prompt_len] + offset += prompt_len + pooled_data.append(pooled_data_i) + + pooled_data_lst = [] + for pooled_data_i in pooled_data: + + if self.pooler is not None: + final_shape_tensor = self.pooler(pooled_data_i) + else: + final_shape_tensor = self.classifier(pooled_data_i) + + pooled_data_lst.append(final_shape_tensor) + + pooled_output = torch.stack(pooled_data_lst) + + if self.pooler is not None: + # apply classifier once on the full batch if possible + pooled_output = self.classifier(pooled_output) + + if isinstance(pooling_metadata, V0PoolingMetadata): + use_cross_encoder_list = [ + pooling_param.use_cross_encoder + for _, pooling_param in pooling_metadata.seq_groups + ] + else: + use_cross_encoder_list = [ + pooling_param.use_cross_encoder + for pooling_param in pooling_metadata.pooling_params + ] + + # shape of scores: (batch_size, num_labels) + if all(use_cross_encoder == use_cross_encoder_list[0] + for use_cross_encoder in use_cross_encoder_list): + act_fn = self._get_act_fn(use_cross_encoder_list[0]) + scores = act_fn(pooled_output) + else: + scores = torch.stack([ + self._get_act_fn(use_cross_encoder)(vecs) + for use_cross_encoder, vecs in zip(use_cross_encoder_list, + pooled_output) + ]) + + pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores] + return PoolerOutput(outputs=pooled_outputs) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py new file mode 100644 index 0000000..b244fbc --- /dev/null +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Literal, get_args + +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + +QuantizationMethods = Literal[ + "aqlm", + "awq", + "deepspeedfp", + "tpu_int8", + "fp8", + "ptpc_fp8", + "fbgemm_fp8", + "modelopt", + "modelopt_fp4", + "marlin", + "bitblas", + "gguf", + "gptq_marlin_24", + "gptq_marlin", + "gptq_bitblas", + "awq_marlin", + "gptq", + "compressed-tensors", + "bitsandbytes", + "qqq", + "hqq", + "experts_int8", + "neuron_quant", + "ipex", + "quark", + "moe_wna16", + "torchao", + "auto-round", + "rtn", + "blockwise_int8", + "slimquant_w4a8", + "slimquant_w4a8_marlin" +] +QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) + +# The customized quantization methods which will be added to this dict. +_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {} + + +def register_quantization_config(quantization: str): + """Register a customized vllm quantization config. + + When a quantization method is not supported by vllm, you can register a customized + quantization config to support it. + + Args: + quantization (str): The quantization method name. + + Examples: + >>> from vllm.model_executor.layers.quantization import register_quantization_config + >>> from vllm.model_executor.layers.quantization import get_quantization_config + >>> from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + >>> + >>> @register_quantization_config("my_quant") + ... class MyQuantConfig(QuantizationConfig): + ... pass + >>> + >>> get_quantization_config("my_quant") + + """ # noqa: E501 + + def _wrapper(quant_config_cls): + if quantization in QUANTIZATION_METHODS: + raise ValueError( + f"The quantization method `{quantization}` is already exists.") + if not issubclass(quant_config_cls, QuantizationConfig): + raise ValueError("The quantization config must be a subclass of " + "`QuantizationConfig`.") + _CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls + QUANTIZATION_METHODS.append(quantization) + return quant_config_cls + + return _wrapper + + +def get_quantization_config(quantization: str) -> type[QuantizationConfig]: + if quantization not in QUANTIZATION_METHODS: + raise ValueError(f"Invalid quantization method: {quantization}") + + # lazy import to avoid triggering `torch.compile` too early + from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig + + from .aqlm import AQLMConfig + from .auto_round import AutoRoundConfig + from .awq import AWQConfig + from .awq_marlin import AWQMarlinConfig + from .bitblas import BitBLASConfig + from .bitsandbytes import BitsAndBytesConfig + from .compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsConfig) + from .deepspeedfp import DeepSpeedFPConfig + from .experts_int8 import ExpertsInt8Config + from .fbgemm_fp8 import FBGEMMFp8Config + from .fp8 import Fp8Config + from .gguf import GGUFConfig + from .gptq import GPTQConfig + from .gptq_bitblas import GPTQBitBLASConfig + from .gptq_marlin import GPTQMarlinConfig + from .gptq_marlin_24 import GPTQMarlin24Config + from .hqq_marlin import HQQMarlinConfig + from .ipex_quant import IPEXConfig + from .marlin import MarlinConfig + from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config + from .moe_wna16 import MoeWNA16Config + from .neuron_quant import NeuronQuantConfig + from .ptpc_fp8 import PTPCFp8Config + from .qqq import QQQConfig + from .rtn import RTNConfig + from .torchao import TorchAOConfig + from .tpu_int8 import Int8TpuConfig + from .blockwise_int8 import BlockInt8Config + from .slimquant_w4a8 import SlimQuantW4A8Int8Config + from .slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig + + method_to_config: dict[str, type[QuantizationConfig]] = { + "aqlm": AQLMConfig, + "awq": AWQConfig, + "deepspeedfp": DeepSpeedFPConfig, + "tpu_int8": Int8TpuConfig, + "fp8": Fp8Config, + "fbgemm_fp8": FBGEMMFp8Config, + "modelopt": ModelOptFp8Config, + "modelopt_fp4": ModelOptNvFp4Config, + "marlin": MarlinConfig, + "bitblas": BitBLASConfig, + "gguf": GGUFConfig, + "gptq_marlin_24": GPTQMarlin24Config, + "gptq_marlin": GPTQMarlinConfig, + "gptq_bitblas": GPTQBitBLASConfig, + "awq_marlin": AWQMarlinConfig, + "gptq": GPTQConfig, + "compressed-tensors": CompressedTensorsConfig, + "bitsandbytes": BitsAndBytesConfig, + "ptpc_fp8": PTPCFp8Config, + "qqq": QQQConfig, + "hqq": HQQMarlinConfig, + "experts_int8": ExpertsInt8Config, + "neuron_quant": NeuronQuantConfig, + "ipex": IPEXConfig, + "quark": QuarkConfig, + "moe_wna16": MoeWNA16Config, + "torchao": TorchAOConfig, + "auto-round": AutoRoundConfig, + "rtn": RTNConfig, + "blockwise_int8": BlockInt8Config, + "slimquant_w4a8":SlimQuantW4A8Int8Config, + "slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig, + } + # Update the `method_to_config` with customized quantization methods. + method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) + + return method_to_config[quantization] + + +__all__ = [ + "QuantizationConfig", + "QuantizationMethods", + "get_quantization_config", + "QUANTIZATION_METHODS", +] \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py new file mode 100644 index 0000000..2ea8c5d --- /dev/null +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -0,0 +1,376 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Supports AQLM compression, see https://github.com/Vahe1994/AQLM +# and https://arxiv.org/pdf/2401.06118.pdf + +import math +from typing import Any, Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + + +def get_int_dtype(nbits: int) -> torch.dtype: + if nbits <= 8: + return torch.int8 + if nbits <= 16: + return torch.int16 + if nbits <= 32: + return torch.int32 + if nbits <= 64: + return torch.int64 + raise ValueError(f"No dtype available for {nbits}-bit codebooks") + + +@torch.inference_mode() +def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor: + return data.to(torch.int64) % (2**nbits) + + +def dequantize_weight(codes: torch.Tensor, + codebooks: torch.Tensor, + scales: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Decode float weights from quantization codes. Differentiable. + :param codes: tensor of integer quantization codes, shape + [*dims, num_out_groups, num_in_groups, num_codebooks] + :param codebooks: tensor of vectors for each quantization code, + [num_codebooks, codebook_size, out_group_size, in_group_size] + :param scales: weight will be multiplied by this factor, must be + broadcastble with + [*dims, out_groups, num_in_groups, out_group_size, in_group_size] + :return: reconstructed weight tensor of shape + [*dims, num_in_groups*group_size] + """ + num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:] + num_codebooks, codebook_size, out_group_size, in_group_size = \ + codebooks.shape + out_features = num_out_groups * out_group_size + in_features = num_in_groups * in_group_size + codebook_offsets = torch.arange( + 0, num_codebooks * codebook_size, codebook_size, + device=codes.device) # shape: [num_codebooks] + reconstructed_weight_flat = F.embedding_bag( + codes.flatten(0, -2) + codebook_offsets, + codebooks.flatten(0, 1).flatten(-2, -1), + mode="sum" + ) # [prod(dims) * num_out_groups * num_in_groups, out_group_size + # * in_group_size] + + reconstructed_weight_groupwise = reconstructed_weight_flat.view( + list(codes.shape[:-3]) + + [num_out_groups, num_in_groups, out_group_size, in_group_size]) + if scales is not None: + reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul( + scales) + return reconstructed_weight_groupwise.swapaxes( + -3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features]) + + +def dequantize_gemm( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + bias: Optional[torch.Tensor], +) -> torch.Tensor: + dequantized_weight = dequantize_weight( + unpack_int_data(codes, codebooks.shape[1].bit_length() - 1), + codebooks, + scales, + ) + return F.linear(input, dequantized_weight, bias) + + +# Generic dequantization, slow but flexible. +def generic_dequantize_gemm( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: list[int], + bias: Optional[torch.Tensor], +) -> torch.Tensor: + output_shape = input.shape[:-1] + (scales.shape[0], ) + output = torch.empty(output_shape, dtype=input.dtype, device=input.device) + num_outputs = len(output_partition_sizes) + + # break the inputs and codebooks apart then combine the outputs. + # Surprisingly (to me) this is faster than doing 3 de-quants and 1 big + # multiply at the end. + num_codebooks = codebooks.shape[0] // num_outputs + assert (scales.shape[0] == codes.shape[0]) + assert (sum(output_partition_sizes) == scales.shape[0]) + output_offset = 0 + codebooks_offset = 0 + for output_size in output_partition_sizes: + shard_output = dequantize_gemm( + input, codes.narrow(0, output_offset, output_size), + codebooks.narrow(0, codebooks_offset, num_codebooks), + scales.narrow(0, output_offset, output_size), None + if bias is None else bias.narrow(0, output_offset, output_size)) + + output_slice = output.narrow(-1, output_offset, output_size) + assert (output_slice.shape == shard_output.shape) + output_slice.copy_(shard_output) + output_offset += output_size + codebooks_offset += num_codebooks + return output + + +# Optimized dequnantize/decompression kernels, supports 1x16 and 2x8 +# at 6 and 9 times faster than the generic version above, respectively. +def optimized_dequantize_gemm( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: list[int], + bias: Optional[torch.Tensor], +) -> torch.Tensor: + weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) + + if bias is None: + # scaling the output is fastest, so we do that when possible. + output = F.linear(input, weights, bias) + orig_shape = output.shape + flattened_output = output.view(-1, output.size(-1)) + f_scales = scales.view(-1, scales.shape[0]) + b_scales = f_scales.expand(flattened_output.shape[0], -1) + flattened_output *= b_scales + return output.view(orig_shape) + else: + b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( + -1, weights.shape[1]) + weights *= b_scales + return F.linear(input, weights, bias) + + +class AQLMConfig(QuantizationConfig): + """Config class for AQLM. + + Reference: https://github.com/Vahe1994/AQLM + """ + + def __init__( + self, + in_group_size: int, + nbits_per_codebook: int, + num_codebooks: int, + out_group_size: int, + ) -> None: + super().__init__() + self.in_group_size = in_group_size + self.nbits_per_codebook = nbits_per_codebook + self.num_codebooks = num_codebooks + self.out_group_size = out_group_size + + # out_group_size > 1 is untested, and probably won't work as-is. + assert (self.out_group_size == 1) + self.pack_factor = (self.in_group_size * self.out_group_size) + + def __repr__(self) -> str: + return (f"AQLMConfig(in_group_size={self.in_group_size}, " + f"nbits_per_codebook={self.nbits_per_codebook}, " + f"num_codebooks={self.num_codebooks}, " + f"out_group_size={self.out_group_size})") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "aqlm" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] # no extra configs. + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "AQLMConfig": + in_group_size = cls.get_from_keys(config, ["in_group_size"]) + nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"]) + num_code_books = cls.get_from_keys(config, ["num_codebooks"]) + out_group_size = cls.get_from_keys(config, ["out_group_size"]) + return cls(in_group_size, nbits_per_codebook, num_code_books, + out_group_size) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["AQLMLinearMethod"]: + if isinstance(layer, LinearBase): + return AQLMLinearMethod(self) + return None + + +class AQLMLinearMethod(LinearMethodBase): + """Linear method for AQLM. + + Args: + quant_config: The AQLM quantization config. + """ + + def __init__(self, quant_config: AQLMConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + del output_size # Unused. + del input_size # Unused. + + if params_dtype != torch.half: + raise ValueError("Only half is currently supported by aqlm") + if input_size_per_partition % self.quant_config.in_group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.out_group_size != 0: + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + codes = Parameter( + torch.empty( + # There could actually be two pack factors, one along input and + # one along output, but we don't currently support + # out_group_size, and only the one along output needs to be + # marked with "packed_dim" in order for QKVLinear to work. + output_size_per_partition, + input_size_per_partition // self.quant_config.pack_factor, + self.quant_config.num_codebooks, + dtype=get_int_dtype(self.quant_config.nbits_per_codebook), + ), + requires_grad=False, + ) + + set_weight_attrs( + codes, + { + "input_dim": 1, + "output_dim": 0, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }, + ) + + codebooks = Parameter( + torch.empty( + self.quant_config.num_codebooks * len(output_partition_sizes), + 2**self.quant_config.nbits_per_codebook, + self.quant_config.out_group_size, + self.quant_config.in_group_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + codebooks, + { + # metadata indicates fixed size concatenated along dim 0 + "is_metadata": True, + "output_partition_sizes": output_partition_sizes + }, + ) + + scales = Parameter( + torch.empty( + ( + output_size_per_partition // + self.quant_config.out_group_size, + 1, + 1, + 1, + ), + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, + { + "output_dim": 0, + "packed_dim": 0, + "pack_factor": self.quant_config.out_group_size + }, + ) + + layer.register_parameter("codes", codes) + set_weight_attrs(codes, extra_weight_attrs) + layer.register_parameter("codebooks", codebooks) + set_weight_attrs(codebooks, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + codebooks = layer.codebooks + codes = layer.codes + scales = layer.scales + output_partition_sizes = getattr(codebooks, "output_partition_sizes", + []) + + nbooks = codes.shape[2] + ingroups = codebooks.shape[3] + outgroups = codebooks.shape[2] + bits = codebooks.shape[1] + + # We support these formats with dedicated gemm and decompression + # kernels. + if ingroups == 8 and outgroups == 1 and ( + (bits == 256 and nbooks == 2) or (bits == 65536 and nbooks == 1)): + + # thresholds determined by timings on an A6000, one GPU + use_gemv = math.prod(x.shape[:-1]) <= 6 + + return ops.aqlm_gemm( + x, + codes, + codebooks, + scales, + output_partition_sizes, + bias, + ) if use_gemv else optimized_dequantize_gemm( + x, + codes, + codebooks, + scales, + output_partition_sizes, + bias, + ) + + # fall back all unoptimized formats + return generic_dequantize_gemm( + x, + codes, + codebooks, + scales, + output_partition_sizes, + bias, + ) diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py new file mode 100644 index 0000000..ea17cd5 --- /dev/null +++ b/vllm/model_executor/layers/quantization/auto_round.py @@ -0,0 +1,310 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from fractions import Fraction +from typing import Any, Optional, Union + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + + +class AutoRoundConfig(QuantizationConfig): + """Config class for AutoRound. + Reference: https://arxiv.org/pdf/2309.05516 + """ + + SUPPORTED_BITS = {2, 3, 4, 8} + SUPPORTED_DTYPES = {"int"} + SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"} + SUPPORTED_BACKENDS = { + "auto", "gptq", "gptq:marlin", "awq", "awq:marlin", "marlin", "ipex" + } + + def __init__( + self, + weight_bits: int, + group_size: int, + sym: bool = True, + packing_format: str = "auto_round:auto_gptq", + block_name_to_quantize: Optional[Union[str, list[str]]] = None, + extra_config: Optional[dict[str, Any]] = None, + data_type: str = "int", + backend: str = "auto", + ) -> None: + super().__init__() + if weight_bits not in self.SUPPORTED_BITS: + raise ValueError(f"Unsupported weight_bits: {weight_bits}, " + f"currently only support {self.SUPPORTED_BITS}") + if data_type not in self.SUPPORTED_DTYPES: + raise ValueError( + f"Unsupported data_type: {data_type}," + f" currently only support {self.SUPPORTED_DTYPES}") + if packing_format not in self.SUPPORTED_FORMATS: + raise ValueError( + f"Unsupported packing_format: {packing_format}, " + f"currently only support {self.SUPPORTED_FORMATS}") + if backend not in self.SUPPORTED_BACKENDS: + raise ValueError( + f"Unsupported backend: {backend}, " + f"currently only support {self.SUPPORTED_BACKENDS}") + + self.weight_bits = weight_bits + self.group_size = group_size + self.sym = sym + self.packing_format = packing_format + self.block_name_to_quantize = (block_name_to_quantize.split(",") if + isinstance(block_name_to_quantize, str) + else block_name_to_quantize) + self.extra_config = extra_config + self.data_type = data_type + self.backend = backend + self.pack_factor = Fraction(32, weight_bits) + + def __repr__(self) -> str: + return (f"AutoRoundConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, sym={self.sym})") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "auto-round" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["quantization_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig": + return cls( + weight_bits=cls.get_from_keys(config, ["bits"]), + group_size=cls.get_from_keys(config, ["group_size"]), + sym=cls.get_from_keys(config, ["sym"]), + packing_format=cls.get_from_keys_or(config, ["packing_format"], + "auto_round:auto_gptq"), + block_name_to_quantize=cls.get_from_keys_or( + config, ["block_name_to_quantize", "to_quant_block_names"], + None), + extra_config=cls.get_from_keys_or(config, ["extra_config"], None), + data_type=cls.get_from_keys_or(config, ["data_type"], "int"), + backend=cls.get_from_keys_or(config, ["backend", "vllm_backend"], + "auto"), + ) + + def get_layer_config(self, layer, layer_name: str): + # Priority: extra_config > block_name_to_quantize > type fallback + if self.extra_config and layer_name in self.extra_config: + cfg = self.extra_config[layer_name] + return cfg.get("bits", self.weight_bits), cfg.get( + "group_size", self.group_size), cfg.get("sym", self.sym) + + quantized = True + if self.block_name_to_quantize: + quantized = any( + layer_name.startswith(name) + for name in self.block_name_to_quantize) + elif isinstance(layer, ParallelLMHead): + quantized = False + + return (self.weight_bits, self.group_size, + self.sym) if quantized else (16, -1, True) + + def check_quantized(self, weight_bits: int) -> bool: + return weight_bits < 16 + + def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"): + from vllm.model_executor.layers.fused_moe import FusedMoE + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supported, check_moe_marlin_supports_layer) + + weight_bits, group_size, sym = self.get_layer_config(layer, prefix) + if not self.check_quantized(weight_bits): + if isinstance(layer, (LinearBase, ParallelLMHead)): + return UnquantizedLinearMethod() + else: + return None + + logger.debug("[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s", + prefix, layer.__class__.__name__, weight_bits, group_size, + sym) + if backend == "auto" or "marlin" in backend: + AWQ_TYPE_MAP = { + 4: scalar_types.uint4, + 8: scalar_types.uint8, + } + use_marlin = (weight_bits + in AWQ_TYPE_MAP) and check_marlin_supported( + AWQ_TYPE_MAP[weight_bits], group_size, not sym) + + if isinstance(layer, FusedMoE): + use_marlin = use_marlin and check_moe_marlin_supports_layer( + layer, group_size) + + else: + use_marlin = False + if use_marlin: + from vllm.model_executor.layers.quantization.awq_marlin import ( + AWQMarlinConfig, AWQMarlinLinearMethod, AWQMoEMethod) + quant_args_marlin = AWQMarlinConfig(weight_bits=weight_bits, + group_size=group_size, + zero_point=not sym, + lm_head_quantized=False, + full_config={}, + modules_to_not_convert=[]) + else: + from vllm.model_executor.layers.quantization.awq import ( + AWQConfig, AWQLinearMethod) + quant_args = AWQConfig( + weight_bits=weight_bits, + group_size=group_size, + zero_point=not sym, + ) + + if isinstance(layer, FusedMoE): + if use_marlin: + return AWQMoEMethod(quant_args_marlin) + from vllm.model_executor.layers.quantization.moe_wna16 import ( + MoeWNA16Config) + config = { + "quant_method": "awq", + "bits": weight_bits, + "group_size": group_size, + "zero_point": not sym, + "lm_head": False, + } + return MoeWNA16Config.from_config(config).get_quant_method( + layer, prefix) + + if isinstance(layer, (LinearBase, ParallelLMHead)): + if use_marlin: + return AWQMarlinLinearMethod(quant_args_marlin) + else: + return AWQLinearMethod(quant_args) + return None + + def apply_gptq_quant_layer(self, + layer, + prefix: str, + backend: str = "auto"): + from vllm.model_executor.layers.fused_moe import FusedMoE + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supported, check_moe_marlin_supports_layer) + weight_bits, group_size, sym = self.get_layer_config(layer, prefix) + if not self.check_quantized(weight_bits): + if isinstance(layer, (LinearBase, ParallelLMHead)): + return UnquantizedLinearMethod() + else: + return None + + logger.debug("[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s", + prefix, layer.__class__.__name__, weight_bits, group_size, + sym) + if backend == "auto" or "marlin" in backend: + GPTQ_TYPE_MAP = { + (4, True): scalar_types.uint4b8, + (8, True): scalar_types.uint8b128, + } + use_marlin = ((weight_bits, sym) in GPTQ_TYPE_MAP + and check_marlin_supported( + GPTQ_TYPE_MAP[(weight_bits, sym)], + group_size, + has_zp=not sym)) + if isinstance(layer, FusedMoE): + use_marlin = use_marlin and check_moe_marlin_supports_layer( + layer, group_size) + else: + use_marlin = False + if use_marlin: + from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig, GPTQMarlinLinearMethod, GPTQMarlinMoEMethod) + quant_args_marlin = GPTQMarlinConfig(weight_bits=weight_bits, + group_size=group_size, + is_sym=sym, + lm_head_quantized=False, + desc_act=False, + dynamic={}, + full_config={}) + else: + from vllm.model_executor.layers.quantization.gptq import ( + GPTQConfig, GPTQLinearMethod) + quant_args = GPTQConfig(weight_bits=weight_bits, + group_size=group_size, + lm_head_quantized=False, + desc_act=False, + dynamic={}) + + if isinstance(layer, FusedMoE): + if use_marlin: + from vllm.model_executor.layers.quantization.moe_wna16 import ( + MoeWNA16Config) + config = { + "quant_method": "gptq", + "bits": weight_bits, + "group_size": group_size, + "sym": sym, + "lm_head": False, + } + return MoeWNA16Config.from_config(config).get_quant_method( + layer, prefix) + return GPTQMarlinMoEMethod(quant_args_marlin) + + if isinstance(layer, (LinearBase, ParallelLMHead)): + if use_marlin: + return GPTQMarlinLinearMethod(quant_args_marlin) + else: + return GPTQLinearMethod(quant_args) + + return None + + def apply_ipex_quant_layer(self, layer, prefix: str): + weight_bits, group_size, sym = self.get_layer_config(layer, prefix) + if not self.check_quantized(weight_bits): + if isinstance(layer, (LinearBase, ParallelLMHead)): + return UnquantizedLinearMethod() + else: + return None + from vllm.model_executor.layers.quantization.ipex_quant import ( + IPEXAWQLinearMethod, IPEXConfig, IPEXGPTQLinearMethod) + if isinstance(layer, (LinearBase, ParallelLMHead)): + if "awq" in self.packing_format: + config = IPEXConfig(method="awq", + weight_bits=weight_bits, + group_size=group_size) + return IPEXAWQLinearMethod(config) + elif "gptq" in self.packing_format: + config = IPEXConfig(method="gptq", + weight_bits=weight_bits, + group_size=group_size) + return IPEXGPTQLinearMethod(config) + else: + raise ValueError( + f"ipex backend only supports awq " + f"and gtpq format,but got {self.packing_format}") + else: + return None + + def get_quant_method(self, layer: torch.nn.Module, prefix: str): + if (current_platform.is_cpu() or current_platform.is_xpu() + or self.backend == "ipex"): + return self.apply_ipex_quant_layer(layer, prefix) + if "gptq" in self.packing_format or "gptq" in self.backend: + return self.apply_gptq_quant_layer(layer, prefix) + if "awq" in self.packing_format or "awq" in self.backend: + return self.apply_awq_quant_layer(layer, prefix) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py new file mode 100644 index 0000000..e3a0336 --- /dev/null +++ b/vllm/model_executor/layers/quantization/awq.py @@ -0,0 +1,362 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Optional, Union + +import torch +import os +import torch.nn.functional as F +import vllm.envs as envs +import json +import math +from vllm.platforms import current_platform +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.parameter import (GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton +from vllm.logger import init_logger +logger = init_logger(__name__) +triton_configs_dict={} + +def get_triton_cache(file_path): + #会将所报错的json文件以字典的形式return出来 + + if os.path.exists(file_path): + with open(file_path, 'r') as file: + cachedata = json.load(file) + + #把所有的cache解析成key:config的形式:[M_N_K]:[config] + for key, value in cachedata.items(): + for sub_key, sub_value in value.items(): + configs_key= f"{sub_key}_{key}" + configs_value={ + 'SPLIT_K': int(sub_value["SPLIT_K"]), + 'BLOCK_SIZE_M': int(sub_value["BLOCK_SIZE_M"]), + 'BLOCK_SIZE_N': int(sub_value["BLOCK_SIZE_N"]), + 'BLOCK_SIZE_K': int(sub_value["BLOCK_SIZE_K"]), + 'GROUP_SIZE_M': int(sub_value["GROUP_SIZE_M"]), + 'num_stages':int(sub_value['num_stages']), + 'num_warps':int(sub_value['num_warps']) + } + if 'num_ldmatrixes' in sub_value: + configs_value["num_ldmatrixes"] = int(sub_value['num_ldmatrixes']) + triton_configs_dict[configs_key]=configs_value + logger.info("%s have loaded!", file_path) + +def default_execution(k,n): + configs_key= f"1_{n}_{k}" + if configs_key in triton_configs_dict: + return + script_dir = os.path.dirname(os.path.abspath(__file__)) + cache_json_file=f"{script_dir}/configs/awq/" + device_name = current_platform.get_device_name().replace(" ", "_") + filename = f"AWQ_{n}_{k}_{device_name}.json" + file_full_path = os.path.join(cache_json_file, filename) + + if os.path.isfile(file_full_path) and file_full_path.endswith(".json"): + # 如果是文件,则添加到列表 + get_triton_cache(file_full_path) + return + + +def getspec_config(M,N,K): + m_config = M + if M > 16: + # 直接计算 2 的幂 + m_config = 1 + while m_config < M: + m_config *= 2 + if f"{m_config}_{N}_{K}" in triton_configs_dict: + return triton_configs_dict[f"{m_config}_{N}_{K}"] + else: + return None + + +class AWQShareWorkSpace: + _instance = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super(AWQShareWorkSpace, cls).__new__(cls, *args, **kwargs) + cls._instance._initialize() + return cls._instance + + def _initialize(self): + self.awqworkshapcesize = ops.GetAWQShareWorkspaceSize() + self.awqworkshapce = ops.GetAWQShareWorkspace() + +logger = init_logger(__name__) + + +class AWQConfig(QuantizationConfig): + """Config class for AWQ. + + Reference: https://arxiv.org/abs/2306.00978 + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + zero_point: bool, + modules_to_not_convert: Optional[list[str]] = None, + ) -> None: + super().__init__() + self.weight_bits = weight_bits + self.group_size = group_size + self.zero_point = zero_point + self.modules_to_not_convert = modules_to_not_convert or [] + + if self.weight_bits != 4: + raise ValueError( + "Currently, only 4-bit weight quantization is supported for " + f"AWQ, but got {self.weight_bits} bits.") + self.pack_factor = 32 // self.weight_bits + + def __repr__(self) -> str: + return (f"AWQConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"zero_point={self.zero_point}, " + f"modules_to_not_convert={self.modules_to_not_convert})") + + def get_name(self) -> QuantizationMethods: + return "awq" + + def get_supported_act_dtypes(self) -> list[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + # The AWQ kernel only supports Turing or newer GPUs. + return 75 + + @staticmethod + def get_config_filenames() -> list[str]: + return [ + "quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq + # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq + "quantize_config.json", + ] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "AWQConfig": + weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) + group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) + zero_point = cls.get_from_keys(config, ["zero_point"]) + modules_to_not_convert = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None) + return cls(weight_bits, group_size, zero_point, modules_to_not_convert) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]: + if isinstance(layer, LinearBase): + if is_layer_skipped_awq(prefix, self.modules_to_not_convert): + return UnquantizedLinearMethod() + return AWQLinearMethod(self) + elif isinstance(layer, FusedMoE): + # Lazy import to avoid circular import. + from .awq_marlin import AWQMarlinConfig, AWQMoEMethod + from .moe_wna16 import MoeWNA16Config + from .utils.marlin_utils import check_moe_marlin_supports_layer + if not check_moe_marlin_supports_layer(layer, self.group_size): + logger.warning_once( + f"Layer '{prefix}' is not supported by AWQMoeMarlin. " + "Falling back to Moe WNA16 kernels.") + config = { + "quant_method": "awq", + "bits": self.weight_bits, + "group_size": self.group_size, + "zero_point": self.zero_point, + "lm_head": False, + } + return MoeWNA16Config.from_config(config).get_quant_method( + layer, prefix) + marlin_compatible_config_dict = { + "quant_method": "awq", + "bits": self.weight_bits, + "group_size": self.group_size, + "zero_point": self.zero_point, + "lm_head": False, + "modules_to_not_convert": self.modules_to_not_convert, + } + awq_marlin_config = AWQMarlinConfig.from_config( + marlin_compatible_config_dict) + return AWQMoEMethod(awq_marlin_config) + return None + + +def is_layer_skipped_awq(prefix: str, modules_to_not_convert: list[str]): + return any(module_name in prefix for module_name in modules_to_not_convert) + + +class AWQLinearMethod(LinearMethodBase): + """Linear method for AWQ. + + Args: + quant_config: The AWQ quantization config. + """ + + def __init__(self, quant_config: AWQConfig): + self.quant_config = quant_config + self.awqsingleton= AWQShareWorkSpace() + self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + if input_size_per_partition % group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + weight_loader = extra_weight_attrs.get("weight_loader") + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + num_groups = input_size_per_partition // group_size + + qzeros = PackedvLLMParameter( + data=torch.empty( + num_groups, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + scales = GroupQuantScaleParameter(data=torch.empty( + num_groups, + output_size_per_partition, + dtype=params_dtype, + ), + input_dim=0, + output_dim=1, + weight_loader=weight_loader) + + zeros_and_scales = GroupQuantScaleParameter(data=torch.empty( + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + weight_loader=weight_loader) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("qzeros", qzeros) + layer.register_parameter("scales", scales) + layer.register_parameter("zeros_and_scales", zeros_and_scales) + # 加载triton_config + if envs.VLLM_USE_TRITON_AWQ: + default_execution(input_size_per_partition,output_size_per_partition) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if not envs.VLLM_USE_TRITON_AWQ: + + group_size= self.quant_config.group_size + pad_group=2 + dim_n = layer.scales.data.shape[1] + dim_k = layer.qweight.data.shape[0] + _qw, _sz=ops.convert_s4(layer.qweight,layer.qzeros,layer.scales.to(torch.float16),int(group_size)) + sz = ops.sz_permute(_sz).reshape(-1,dim_n) + sz = sz.reshape(dim_n,-1) + _qw = _qw.reshape(dim_n,-1) + + if dim_k % 4096==0 and self.use_awq_pad: + zeros_and_scalse_pad = torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda() + sz = torch.cat((sz,zeros_and_scalse_pad),dim=1).contiguous() + qweight_pad = torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda() + _qw=torch.cat((_qw,qweight_pad),dim=1).contiguous() + + layer.qweight = torch.nn.Parameter(_qw, requires_grad=False) + layer.zeros_and_scales = torch.nn.Parameter(sz, requires_grad=False) + layer.qzeros = None + layer.scales = None + else: + + layer.qweight = torch.nn.Parameter(layer.qweight.data, + requires_grad=False) + layer.qzeros = torch.nn.Parameter(layer.qzeros.data, + requires_grad=False) + layer.scales = torch.nn.Parameter(layer.scales.data, + requires_grad=False) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qweight = layer.qweight + zeros_and_scales = layer.zeros_and_scales + qzeros = layer.qzeros + scales = layer.scales + pack_factor = self.quant_config.pack_factor + out_shape = (x.shape[:-1] + (qweight.shape[0] * 1, )) + reshaped_x = x.reshape(-1, x.shape[-1]) + + m = reshaped_x.shape[0] + k = reshaped_x.shape[-1] + n = qweight.shape[0] + + if self.use_awq_pad: + if k % 4096 == 0: + padding_group=2 + else: + padding_group=0 + else: + padding_group=0 + + if envs.VLLM_USE_TRITON_AWQ: + best_config=getspec_config(m,n,k) + out = awq_gemm_triton(reshaped_x, qweight, scales, qzeros, pack_factor, best_config) + out_shape = (x.shape[:-1] + (qweight.shape[1] * 8, )) + else: + out = torch.ops.vllm.awq_gemm(reshaped_x, + qweight, + zeros_and_scales, + m, + n, + k, + self.quant_config.group_size, + padding_group, + self.awqsingleton.awqworkshapce, + self.awqsingleton.awqworkshapcesize) + + if bias is not None: + out.add_(bias) + return out.reshape(out_shape) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py new file mode 100644 index 0000000..f25de2b --- /dev/null +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -0,0 +1,555 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Callable, Optional + +import torch +from torch.nn import Parameter + +import vllm.model_executor.layers.fused_moe # noqa +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, + UnquantizedFusedMoEMethod) +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod, + set_weight_attrs) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.awq import (AWQConfig, + is_layer_skipped_awq) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, + check_marlin_supports_layer, check_moe_marlin_supports_layer, + marlin_make_empty_g_idx, marlin_make_workspace_new, + marlin_moe_permute_scales, marlin_permute_scales, + moe_awq_to_marlin_zero_points, verify_marlin_supported, + verify_marlin_supports_shape, + awq_marlin_moe_permute_sz) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parameter import (GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + + +class AWQMarlinConfig(QuantizationConfig): + """Config class for AWQ Marlin""" + + # num_bits -> type + TYPE_MAP = { + 4: scalar_types.uint4, + 8: scalar_types.uint8, + } + + def __init__(self, weight_bits: int, group_size: int, zero_point: bool, + lm_head_quantized: bool, + modules_to_not_convert: Optional[list[str]], + full_config: dict[str, Any]) -> None: + super().__init__() + self.pack_factor = 32 // weight_bits # packed into int32 + self.group_size = group_size + self.zero_point = zero_point + self.lm_head_quantized = lm_head_quantized + self.weight_bits = weight_bits + self.modules_to_not_convert = modules_to_not_convert or [] + self.full_config = full_config + + if self.weight_bits not in self.TYPE_MAP: + raise ValueError(f"Unsupported num_bits = {self.weight_bits}. " + f"Supported num_bits = {self.TYPE_MAP.keys()}") + + self.quant_type = self.TYPE_MAP[self.weight_bits] + + verify_marlin_supported(self.quant_type, + group_size=self.group_size, + has_zp=self.zero_point) + + def __repr__(self) -> str: + return (f"AWQMarlinConfig(quant_type={self.quant_type}, " + f"group_size={self.group_size}, " + f"zero_point={self.zero_point}, " + f"lm_head_quantized={self.lm_head_quantized}, " + f"modules_to_not_convert={self.modules_to_not_convert})") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "awq_marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "AWQMarlinConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + zero_point = cls.get_from_keys(config, ["zero_point"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + modules_to_not_convert = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None) + return cls(weight_bits, group_size, zero_point, lm_head_quantized, + modules_to_not_convert, config) + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg) + is_valid_user_quant = (user_quant is None or user_quant == "marlin" + or user_quant == "awq_marlin") + + if can_convert and is_valid_user_quant: + msg = ("The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "awq": + logger.info("Detected that the model can run with awq_marlin" + ", however you specified quantization=awq explicitly," + " so forcing awq. Use quantization=awq_marlin for" + " faster inference") + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if (isinstance(layer, LinearBase) or + (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + if is_layer_skipped_awq(prefix, self.modules_to_not_convert): + return UnquantizedLinearMethod() + # Check if the layer is supported by AWQMarlin. + if not check_marlin_supports_layer(layer, self.group_size): + # logger.warning_once( + # "Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels.", # noqa: E501 + # prefix, + # ) + return AWQConfig.from_config( + self.full_config).get_quant_method(layer, prefix) + return AWQMarlinLinearMethod(self) + elif isinstance(layer, FusedMoE): + if is_layer_skipped_awq( + prefix, getattr(self, "modules_to_not_convert", [])): + return UnquantizedFusedMoEMethod(layer.moe_config) + from vllm.model_executor.layers.quantization.moe_wna16 import ( + MoeWNA16Config) + if not check_moe_marlin_supports_layer(layer, self.group_size): + logger.warning_once( + f"Layer '{prefix}' is not supported by AWQMoeMarlin. " + "Falling back to Moe WNA16 kernels.") + return MoeWNA16Config.from_config( + self.full_config).get_quant_method(layer, prefix) + return AWQMoEMethod(self) + return None + + @classmethod + def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]): + # Extract data from quant config. + quant_method = quant_config.get("quant_method", "").lower() + num_bits = quant_config.get("bits") + group_size = quant_config.get("group_size") + zero_point = quant_config.get("zero_point") + + # if not current_platform.is_cuda(): + # return False + + if quant_method != "awq": + return False + + # If we cannot find the info needed in the config, cannot convert. + if (num_bits is None or group_size is None or zero_point is None): + return False + + if num_bits not in cls.TYPE_MAP: + return False + + return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits], + group_size=group_size, + has_zp=zero_point) + + +class AWQMarlinLinearMethod(LinearMethodBase): + """Linear method for AWQ Marlin. + + Args: + quant_config: The AWQ Marlin quantization config. + """ + + def __init__(self, quant_config: AWQMarlinConfig) -> None: + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + del output_size + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + verify_marlin_supports_shape( + output_size_per_partition=output_size_per_partition, + input_size_per_partition=input_size_per_partition, + input_size=input_size, + group_size=group_size) + + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + num_groups = input_size_per_partition // group_size + + qzeros = PackedvLLMParameter( + data=torch.empty( + num_groups, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + scales = GroupQuantScaleParameter(data=torch.empty( + num_groups, + output_size_per_partition, + dtype=params_dtype, + ), + input_dim=0, + output_dim=1, + weight_loader=weight_loader) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("qzeros", qzeros) + layer.register_parameter("scales", scales) + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.num_groups = num_groups + + # TODO: Update this docs + # Checkpoints are serialized in AutoAWQ format, which is different from the + # marlin format. This function is called after the weights are loaded. + # Here, we handle the repacking + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = layer.qweight.device + layer.qweight = torch.nn.Parameter(layer.qweight.data, + requires_grad=False) + layer.qzeros = torch.nn.Parameter(layer.qzeros.data, + requires_grad=False) + layer.scales = torch.nn.Parameter(layer.scales.data, + requires_grad=False) + + # Allocate marlin workspace + layer.workspace = marlin_make_workspace_new(device) + + # Repack weights from AWQ format to marlin format. + marlin_qweight = ops.awq_marlin_repack( + layer.qweight, + size_k=layer.input_size_per_partition, + size_n=layer.output_size_per_partition, + num_bits=self.quant_config.quant_type.size_bits) + replace_parameter(layer, "qweight", marlin_qweight) + + # Permute scales from AWQ format to marlin format. + marlin_scales = marlin_permute_scales( + layer.scales, + size_k=layer.input_size_per_partition, + size_n=layer.output_size_per_partition, + group_size=self.quant_config.group_size) + replace_parameter(layer, "scales", marlin_scales) + + # Permute zero-points from AWQ format to marlin format. + marlin_zp = awq_to_marlin_zero_points( + layer.qzeros, + size_k=layer.num_groups, + size_n=layer.output_size_per_partition, + num_bits=self.quant_config.quant_type.size_bits) + replace_parameter(layer, "qzeros", marlin_zp) + + # Not-used + layer.g_idx = marlin_make_empty_g_idx(device) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return apply_awq_marlin_linear( + input=x, + weight=layer.qweight, + weight_scale=layer.scales, + weight_zp=layer.qzeros, + g_idx=layer.g_idx, + g_idx_sort_indices=layer.g_idx_sort_indices, + workspace=layer.workspace, + quant_type=self.quant_config.quant_type, + output_size_per_partition=layer.output_size_per_partition, + input_size_per_partition=layer.input_size_per_partition, + bias=bias) + + +class AWQMoEMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: AWQMarlinConfig): + self.quant_config = quant_config + if self.quant_config.weight_bits != 4: + raise ValueError("AWQMoEMethod only supports 4bit now.") + self.quant_type = scalar_types.uint4 + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + extra_weight_attrs.update({ + "is_transposed": + True, + "quant_method": + FusedMoeWeightScaleSupported.GROUP.value, + }) + + w13_qweight = Parameter( + torch.empty(num_experts, + hidden_size, + 2 * intermediate_size_per_partition // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + + w2_qweight = Parameter(torch.empty(num_experts, + intermediate_size_per_partition, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + num_groups_w13 = hidden_size // self.quant_config.group_size + num_groups_w2 = (intermediate_size_per_partition // + self.quant_config.group_size) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + w13_scales = Parameter(torch.empty(num_experts, + num_groups_w13, + intermediate_size_per_partition * 2, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + + w2_scales = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + + # WEIGHT_ZERO_POINT + # Allocate 2 zero points for w1 and w3 respectively. + w13_qzeros = Parameter( + torch.empty(num_experts, + num_groups_w13, + 2 * intermediate_size_per_partition // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + + w2_qzeros = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + + device = layer.w13_qweight.device + layer.workspace = marlin_make_workspace_new(device, 3) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + num_experts = layer.w13_qweight.shape[0] + device = layer.w13_qweight.device + + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + + marlin_w13_qweight = ops.awq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + size_k=layer.w13_qweight.shape[1], + size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_parameter(layer, "w13_qweight", marlin_w13_qweight) + + marlin_w2_qweight = ops.awq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + size_k=layer.w2_qweight.shape[1], + size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_parameter(layer, "w2_qweight", marlin_w2_qweight) + + # Why does this take the intermediate size for size_k? + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales.to(torch.float16), + size_k=layer.intermediate_size_per_partition, + size_n=layer.w13_scales.shape[2], + group_size=self.quant_config.group_size, + ) + + #replace_parameter(layer, "w13_scales", marlin_w13_scales) + + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales.to(torch.float16), + size_k=layer.intermediate_size_per_partition, + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, + ) + #replace_parameter(layer, "w2_scales", marlin_w2_scales) + + + marlin_w13_zp = moe_awq_to_marlin_zero_points( + layer.w13_qzeros, + size_k=layer.w13_qzeros.shape[1], + size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits) + # replace_parameter(layer, "w13_qzeros", marlin_w13_zp) + + marlin_w2_zp = moe_awq_to_marlin_zero_points( + layer.w2_qzeros, + size_k=layer.w2_qzeros.shape[1], + size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits) + # replace_parameter(layer, "w2_qzeros", marlin_w2_zp) + + marlin_w13_sz = awq_marlin_moe_permute_sz( + marlin_w13_scales, + marlin_w13_zp, + size_k=layer.w13_scales.shape[1] * self.quant_config.group_size, + size_n=layer.w13_scales.shape[2] + ) + marlin_w2_sz = awq_marlin_moe_permute_sz( + marlin_w2_scales, + marlin_w2_zp, + size_k=layer.w2_scales.shape[1] * self.quant_config.group_size, + size_n=layer.w2_scales.shape[2] + ) + replace_parameter(layer, "w13_scales", marlin_w13_sz) + replace_parameter(layer, "w2_scales", marlin_w2_sz) + + layer.w13_qzeros = None + layer.w2_qzeros = None + torch.cuda.empty_cache() + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + use_nn_moe: Optional[bool] = False, + routed_scaling_factor: Optional[float] = None, + use_fused_gate: Optional[bool] = False, + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `AWQMoEMethod` yet.") + + assert activation == "silu", "Only SiLU activation is supported." + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + routed_scaling_factor=routed_scaling_factor, + use_fused_gate=use_fused_gate) + + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_qweight, + layer.w2_qweight, + layer.w13_scales, + layer.w2_scales, + router_logits, + topk_weights, + topk_ids, + # quant_type_id=self.quant_type.id, + # apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_zeros=layer.w13_qzeros, + w2_zeros=layer.w2_qzeros, + workspace=layer.workspace, + num_bits=4 + ) diff --git a/vllm/model_executor/layers/quantization/awq_triton.py b/vllm/model_executor/layers/quantization/awq_triton.py new file mode 100644 index 0000000..71f44c5 --- /dev/null +++ b/vllm/model_executor/layers/quantization/awq_triton.py @@ -0,0 +1,335 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.triton_utils import tl, triton + +AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + +@triton.jit +def awq_dequantize_kernel( + qweight_ptr, # quantized matrix + scales_ptr, # scales, per group + zeros_ptr, # zeros, per group + group_size, # Should always be one of the supported group sizes + result_ptr, # Output matrix + num_cols, # input num cols in qweight + num_rows, # input num rows in qweight + BLOCK_SIZE_X: tl.constexpr, + BLOCK_SIZE_Y: tl.constexpr): + # Setup the pids. + pid_x = tl.program_id(axis=0) + pid_y = tl.program_id(axis=1) + + # Compute offsets and masks for qweight_ptr. + offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) + offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X) + offsets = num_cols * offsets_y[:, None] + offsets_x[None, :] + + masks_y = offsets_y < num_rows + masks_x = offsets_x < num_cols + + masks = masks_y[:, None] & masks_x[None, :] + + # Compute offsets and masks for result output ptr. + result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) + result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange( + 0, BLOCK_SIZE_X * 8) + result_offsets = (8 * num_cols * result_offsets_y[:, None] + + result_offsets_x[None, :]) + + result_masks_y = result_offsets_y < num_rows + result_masks_x = result_offsets_x < num_cols * 8 + result_masks = result_masks_y[:, None] & result_masks_x[None, :] + + # Load the weights. + iweights = tl.load(qweight_ptr + offsets, masks) + + iweights =tl.join(iweights, iweights).reshape(iweights.shape[:-1] + [2 * iweights.shape[-1]]) + iweights =tl.join(iweights, iweights).reshape(iweights.shape[:-1] + [2 * iweights.shape[-1]]) + iweights =tl.join(iweights, iweights).reshape(iweights.shape[:-1] + [2 * iweights.shape[-1]]) + # iweights = tl.interleave(iweights, iweights) + # iweights = tl.interleave(iweights, iweights) + # iweights = tl.interleave(iweights, iweights) + + # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] + # that will map given indices to the correct order. + reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] + + tl.arange(0, 4)[:, None]).reshape(8) + + # Use this to compute a set of shifts that can be used to unpack and + # reorder the values in iweights and zeros. + shifts = reverse_awq_order_tensor * 4 + shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8)) + shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + # Unpack and reorder: shift out the correct 4-bit value and mask. + iweights = (iweights >> shifts) & 0xF + + # Compute zero offsets and masks. + zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) + zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X) + zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :] + + zero_masks_y = zero_offsets_y < num_rows // group_size + zero_masks_x = zero_offsets_x < num_cols + zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :] + + # Load the zeros. + zeros = tl.load(zeros_ptr + zero_offsets, zero_masks) + # zeros = tl.interleave(zeros, zeros) + # zeros = tl.interleave(zeros, zeros) + # zeros = tl.interleave(zeros, zeros) + zeros =tl.join(zeros, zeros).reshape(zeros.shape[:-1] + [2 * zeros.shape[-1]]) + zeros =tl.join(zeros, zeros).reshape(zeros.shape[:-1] + [2 * zeros.shape[-1]]) + zeros =tl.join(zeros, zeros).reshape(zeros.shape[:-1] + [2 * zeros.shape[-1]]) + + zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + # Unpack and reorder: shift out the correct 4-bit value and mask. + zeros = (zeros >> shifts) & 0xF + + # Compute scale offsets and masks. + scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) + scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 + + tl.arange(0, BLOCK_SIZE_X * 8)) + scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] + + scale_offsets_x[None, :]) + scale_masks_y = scale_offsets_y < num_rows // group_size + scale_masks_x = scale_offsets_x < num_cols * 8 + scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :] + + # Load the scales. + scales = tl.load(scales_ptr + scale_offsets, scale_masks) + scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + # Dequantize. + iweights = (iweights - zeros) * scales + iweights = iweights.to(result_ptr.type.element_ty) + + # Finally, store. + tl.store(result_ptr + result_offsets, iweights, result_masks) + + +@triton.jit +def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, + group_size, BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr,SPLIT_K: tl.constexpr): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(1) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + if GROUP_SIZE_M == 1: + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # accumulator_dtype = c_ptr.type.element_ty + BLOCK_SIZE_N_8 = BLOCK_SIZE_N // 8 + N_8 = N // 8 + # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead. + # accumulator = tl.arange(0, BLOCK_SIZE_N) + # accumulator = tl.broadcast_to(accumulator[None, :], + # (BLOCK_SIZE_M, BLOCK_SIZE_N)) + # accumulator = accumulator & 0x0 + # accumulator = accumulator.to(accumulator_dtype) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), + dtype=tl.float32) + + # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] + # that will map given indices to the correct order. + shifts = ((tl.arange(0, 2) * 16)[None, :] + + (tl.arange(0, 4) * 4)[:, None]).reshape(1,8) + + # Create the necessary shifts to use to unpack. + # shifts = reverse_awq_order_tensor * 4 + shifts = tl.broadcast_to(shifts, + (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8)) + shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + # Offsets and masks. + offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + masks_am = offsets_am < M + + offsets_bzn = pid_n * (BLOCK_SIZE_N_8) + tl.arange(0, BLOCK_SIZE_N // 8) + masks_bzn = offsets_bzn < N_8 + + offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + masks_sn = offsets_sn < N + + offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offsets_a = K * offsets_am[:, None] + offsets_k[None, :] + offsets_b = (N_8) * offsets_k[:, None] + offsets_bzn[None, :] + + a_ptrs = a_ptr + offsets_a + b_ptrs = b_ptr + offsets_b + + # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv + # block_offset = BLOCK_SIZE_K * SPLIT_K + # for k in range(0, (K + block_offset - 1) // (block_offset)): + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + masks_k = offsets_k < K + masks_a = masks_am[:, None] & masks_k[None, :] + a = tl.load(a_ptrs, mask=masks_a) + + masks_b = masks_k[:, None] & masks_bzn[None, :] + b = tl.load(b_ptrs, mask=masks_b) + # b = tl.interleave(b, b) + # b = tl.interleave(b, b) + # b = tl.interleave(b, b) + b =tl.join(b, b).reshape(b.shape[:-1] + [2 * b.shape[-1]]) + b =tl.join(b, b).reshape(b.shape[:-1] + [2 * b.shape[-1]]) + b =tl.join(b, b).reshape(b.shape[:-1] + [2 * b.shape[-1]]) + + # Dequantize b. + offsets_szk = (BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size + offsets_szk = offsets_szk + (tl.arange(0,BLOCK_SIZE_K) // group_size) + offsets_z = (N_8) * offsets_szk[:, None] + offsets_bzn[None, :] + masks_zk = offsets_szk < K // group_size + masks_z = masks_zk[:, None] & masks_bzn[None, :] + zeros_ptrs = zeros_ptr + offsets_z + zeros = tl.load(zeros_ptrs, mask=masks_z) + # zeros = tl.interleave(zeros, zeros) + # zeros = tl.interleave(zeros, zeros) + # zeros = tl.interleave(zeros, zeros) + + zeros =tl.join(zeros, zeros).reshape(zeros.shape[:-1] + [2 * zeros.shape[-1]]) + zeros =tl.join(zeros, zeros).reshape(zeros.shape[:-1] + [2 * zeros.shape[-1]]) + zeros =tl.join(zeros, zeros).reshape(zeros.shape[:-1] + [2 * zeros.shape[-1]]) + + zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :] + masks_sk = offsets_szk < K // group_size + masks_s = masks_sk[:, None] & masks_sn[None, :] + scales_ptrs = scales_ptr + offsets_s + scales = tl.load(scales_ptrs, mask=masks_s) + scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + b = (b >> shifts) & 0xF + zeros = (zeros >> shifts) & 0xF + b = (b - zeros) * scales + b = b.to(c_ptr.type.element_ty) + + # Accumulate results. + accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float32) + + offsets_k += BLOCK_SIZE_K * SPLIT_K + a_ptrs += BLOCK_SIZE_K * SPLIT_K + b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N_8) + + c = accumulator.to(c_ptr.type.element_ty) + c_ptrs = c_ptr + N * offsets_am[:, None] + offsets_sn[None, :] + c_mask = masks_am[:, None] & masks_sn[None, :] + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=c_mask) + # tl.store(c_ptrs, c) + else: + tl.atomic_add(c_ptrs, c, mask=c_mask) + + +# qweights - [K , M // 8], int32 +# scales - [K // G, M ], float16 +# zeros - [K // G, M // 8], int32 +def awq_dequantize_triton(qweight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + block_size_x: int = 32, + block_size_y: int = 32) -> torch.Tensor: + K = qweight.shape[0] + M = scales.shape[1] + group_size = qweight.shape[0] // scales.shape[0] + + assert K > 0 and M > 0 + assert scales.shape[0] == K // group_size and scales.shape[1] == M + assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8 + assert group_size <= K + assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K + + # Result tensor: + # number of rows = same as input tensor + # number of cols = 8 x input tensor num cols + result = torch.empty(qweight.shape[0], + qweight.shape[1] * 8, + device=qweight.device, + dtype=scales.dtype) + + Y = qweight.shape[0] # num rows + X = qweight.shape[1] # num cols + + grid = lambda META: ( + triton.cdiv(X, META['BLOCK_SIZE_X']), + triton.cdiv(Y, META['BLOCK_SIZE_Y']), + ) + awq_dequantize_kernel[grid](qweight, + scales, + zeros, + group_size, + result, + X, + Y, + BLOCK_SIZE_X=block_size_x, + BLOCK_SIZE_Y=block_size_y) + + return result + + +# input - [M, K] +# qweight - [K, N // 8] +# qzeros - [K // G, N // 8] +# scales - [K // G, N] +# split_k_iters - parallelism along K-dimension, int, power of 2. +def awq_gemm_triton(input: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + split_k_iters: int, + config=None) -> torch.Tensor: + M, K = input.shape + N = qweight.shape[1] * 8 + group_size = qweight.shape[0] // qzeros.shape[0] + + assert N > 0 and K > 0 and M > 0 + assert qweight.shape[0] == K and qweight.shape[1] == N // 8 + assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8 + assert scales.shape[0] == K // group_size and scales.shape[1] == N + assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0 + assert split_k_iters <= 32 + assert group_size <= K + assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + META['SPLIT_K'], + ) + if config is None: + config= {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8,'SPLIT_K': 8} + if M >256: + #print("INFO:this size not found in json.") + config= {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8,'SPLIT_K': 1} + + result = torch.zeros((M, N), dtype=scales.dtype, device=input.device) + + # A = input, B = qweight, C = result + # A = M x K, B = K x N, C = M x N + awq_gemm_kernel[grid](input, + qweight, + result, + qzeros, + scales, + M, + N, + K, + group_size, + **config) + return result diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py new file mode 100644 index 0000000..4a43351 --- /dev/null +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import inspect +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Optional + +import torch +from torch import nn + +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization import QuantizationMethods + from vllm.model_executor.models.utils import WeightsMapper +else: + QuantizationMethods = str + + +class QuantizeMethodBase(ABC): + """Base class for different quantized methods.""" + + @abstractmethod + def create_weights(self, layer: torch.nn.Module, *weight_args, + **extra_weight_attrs): + """Create weights for a layer. + + The weights will be set as attributes of the layer.""" + raise NotImplementedError + + @abstractmethod + def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor: + """Apply the weights in layer to the input tensor. + + Expects create_weights to have been called before on the layer.""" + raise NotImplementedError + + # Not required functions + def embedding(self, layer: torch.nn.Module, *args, + **kwargs) -> torch.Tensor: + """Gather embeddings in the layer based on indices in the input tensor. + + Expects create_weights to have been called before on the layer.""" + raise NotImplementedError + + def process_weights_after_loading(self, layer: nn.Module) -> None: + """Process the weight after loading. + + This can be used for example, to transpose weights for computation. + """ + return + + +def method_has_implemented_embedding( + method_class: type[QuantizeMethodBase]) -> bool: + """ + Not all quant methods have embedding implemented, so we need to check that + it exists for our given method. We check this by making sure the function + has been changed from the base implementation. + """ + base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", + None) + class_embedding = inspect.getattr_static(method_class, "embedding", None) + + return (class_embedding is not None + and class_embedding is not base_embedding) + + +class QuantizationConfig(ABC): + """Base class for quantization configs.""" + + def __init__(self): + super().__init__() + # mapping is updated by models as they initialize + self.packed_modules_mapping: dict[str, list[str]] = dict() + + @abstractmethod + def get_name(self) -> QuantizationMethods: + """Name of the quantization method.""" + raise NotImplementedError + + @abstractmethod + def get_supported_act_dtypes(self) -> list[torch.dtype]: + """List of supported activation dtypes.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + """Minimum GPU capability to support the quantization method. + + E.g., 70 for Volta, 75 for Turing, 80 for Ampere. + This requirement is due to the custom CUDA kernels used by the + quantization method. + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_config_filenames() -> list[str]: + """List of filenames to search for in the model directory.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig": + """Create a config class from the model's quantization config.""" + raise NotImplementedError + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + """ + Detects if this quantization method can support a given checkpoint + format by overriding the user specified quantization method -- + this method should only be overwritten by subclasses in exceptional + circumstances + """ + return None + + @staticmethod + def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any: + """Get a value from the model's quantization config.""" + for key in keys: + if key in config: + return config[key] + raise ValueError(f"Cannot find any of {keys} in the model's " + "quantization config.") + + @staticmethod + def get_from_keys_or(config: dict[str, Any], keys: list[str], + default: Any) -> Any: + """Get a optional value from the model's quantization config.""" + try: + return QuantizationConfig.get_from_keys(config, keys) + except ValueError: + return default + + @abstractmethod + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional[QuantizeMethodBase]: + """Get the quantize method to use for the quantized layer. + + Args: + layer: The layer for the quant method. + prefix: The full name of the layer in the state dict + Returns: + The quantize method. None if the given layer doesn't support quant + method. + """ + raise NotImplementedError + + def get_cache_scale(self, name: str) -> Optional[str]: + return None + + def apply_vllm_mapper( # noqa: B027 + self, hf_to_vllm_mapper: "WeightsMapper"): + """ + Interface for models to update module names referenced in + quantization configs in order to reflect the vllm model structure + + :param hf_to_vllm_mapper: maps from hf model structure (the assumed + structure of the qconfig) to vllm model structure + """ + # TODO (@kylesayrs): add implementations for all subclasses + pass diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py new file mode 100644 index 0000000..aa8eee8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -0,0 +1,462 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, Optional + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_NUM_BITS, + BITBLAS_SUPPORTED_SYM, MINIMUM_BITBLAS_VERSION) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + + +class BitBLASConfig(QuantizationConfig): + """Config class for BitBLAS. + + Reference: https://github.com/Microsoft/BitBLAS + """ + TORCH_DTYPE = torch.float16 + STORAGE_DTYPE = "int8" # assume int8 storage + TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE) + # "original" or "rescale" or "quantized", + # gptq_with_bitblas prefer "quantized implementation" + ZEROS_MODE = "quantized" + + def __init__( + self, + weight_bits: int, + group_size: Optional[int], + desc_act: Optional[bool], + is_sym: Optional[bool], + quant_method: Optional[str], + lm_head_quantized: bool, + ) -> None: + try: + import bitblas + if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: + raise ImportError( + "bitblas version is wrong. Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + except ImportError as e: + bitblas_import_exception = e + raise ValueError( + "Trying to use the bitblas backend, but could not import" + f"with the following error: {bitblas_import_exception}. " + "Please install bitblas through the following command: " + f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`" + ) from bitblas_import_exception + + if desc_act and group_size == -1: + # In this case, act_order == True is the same as act_order == False + # (since we have only one group per output channel) + desc_act = False + + super().__init__() + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.is_sym = is_sym + self.quant_method = quant_method + self.lm_head_quantized = lm_head_quantized + + # Verify + if self.weight_bits not in BITBLAS_SUPPORTED_NUM_BITS: + raise ValueError( + f"BitBLAS does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {BITBLAS_SUPPORTED_NUM_BITS} " + "are supported.") + + if self.is_sym not in BITBLAS_SUPPORTED_SYM: + raise ValueError( + f"BitBLAS does not support is_sym = {self.is_sym}. " + f"Only sym = {BITBLAS_SUPPORTED_SYM} are supported.") + + storage_dtype = self.STORAGE_DTYPE + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + + self.storage_dtype = storage_dtype + self.storage_torch_dtype = self.TORCH_STORAGE_DTYPE + # 4 Bits packed into 32 bit datatype. + self.pack_factor = storage_nbit // weight_bits + self.nbits = weight_bits + + # Zeros type for the quantized weights. + self.zeros_mode = self.ZEROS_MODE + + def __repr__(self) -> str: + return (f"BitBLASConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act}, " + f"is_sym={self.is_sym}, " + f"quant_method={self.quant_method})") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "bitblas" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["quantize_config.json"] + + @staticmethod + def get_from_keys(config: dict[str, Any], + keys: list[str], + default: Any = None) -> Any: + """Get a value from the model's quantization config.""" + for key in keys: + if key in config: + return config[key] + return default + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "BitBLASConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"], -1) + desc_act = cls.get_from_keys(config, ["desc_act"], False) + is_sym = cls.get_from_keys(config, ["sym"], False) + quant_method = cls.get_from_keys(config, ["quant_method"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(weight_bits, group_size, desc_act, is_sym, quant_method, + lm_head_quantized) + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + # compat: autogptq >=0.8.0 use checkpoint_format: str + # compat: autogptq <=0.7.1 is_bitblas_format: bool + is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas" + or hf_quant_cfg.get("is_bitblas_format", False)) + + is_valid_user_quant = (user_quant is None or user_quant == "gptq" + or user_quant == "bitblas") + + if is_bitblas_format and is_valid_user_quant: + msg = ("The model is serialized in {} format. Using {} kernel.". + format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["BitBLASLinearMethod"]: + if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) + and self.lm_head_quantized): + return BitBLASLinearMethod(self) + return None + + +class BitBLASLinearMethod(LinearMethodBase): + """Linear method for BitBLAS. + + Args: + quant_config: The BitBLAS quantization config. + """ + # USE BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS + # Instead of BITBLAS_OPTIMIZE_FEATURES + # If you want to high contiguous batching + # performance + OPT_FEATURES = BITBLAS_OPTIMIZE_FEATURES + ENABLE_TUNING = True + BITBLAS_DTYPES = { + torch.float32: "float32", + torch.float16: "float16", + torch.bfloat16: "bfloat16", + torch.half: "float16", + torch.int8: "int8", + } + + def __init__(self, quant_config: BitBLASConfig): + self.quant_config = quant_config + + def create_weights_gptq( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + """Creates quantized weights for use in linear operations. + + The function initializes and returns a dictionary containing quantized + weights, scales, and zeros + for performing quantized matrix multiplication operations. + + Args: + input_size_per_partition: The size of the input partition. + output_size_per_partition: The size of the output partition. + input_size: The total size of the input (unused). + output_size: The total size of the output (unused). + params_dtype: + The data type of the parameters (expected to be torch.float16). + + Returns: + A dictionary containing the quantized weights ('qweight'), + scales ('scales'), and zeros ('zeros'). + + Raises: + ValueError: If `params_dtype` is not `torch.float16` or if the + input size per partition is not divisible by the group size in + `quant_config`. + """ + del input_size, output_size # Unused arguments. + weight_loader = extra_weight_attrs["weight_loader"] + + if params_dtype not in self.quant_config.get_supported_act_dtypes(): + raise ValueError("Parameter data type must be torch.float16, " + f"but got {params_dtype}") + group_size = self.quant_config.group_size + if group_size is None: + group_size = -1 + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if (group_size != -1 and input_size_per_partition % group_size != 0): + raise ValueError( + f"Input size per partition ({input_size_per_partition}) must " + f"be divisible by group size ({group_size}).") + + # Initialize or retrieve the BitBLAS matrix multiplication operator. + self._configure_bitblas_matmul( + input_size_per_partition, + output_size_per_partition, + params_dtype=params_dtype, + enable_tuning=self.ENABLE_TUNING, + bias=False, + layout="nt", + bits=self.quant_config.weight_bits, + ) + + # Initialize quantized weights with dimensions + # Quantized 4Bit weights packed. + qweight = PackedvLLMParameter( + data=torch.empty( + self.bitblas_matmul.retrieve_weight_shape(), + device="cuda", + dtype=self.quant_config.storage_torch_dtype, + requires_grad=False, + ), + input_dim=1, + output_dim=0, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + bitblas_tile_size=(self.bitblas_matmul.retrieve_weight_shape()[-2] + if self.bitblas_matmul.propagate_b else None), + weight_loader=weight_loader, + ) + + # Compute the number of input groups for channel-wise quantization. + input_groups = (1 if group_size == -1 else input_size_per_partition // + group_size) + + # Initialize scales and zeros for the quantized weights. + weight_scale_args = { + "data": + torch.empty( + output_size_per_partition, + input_groups, + device="cuda", + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + if input_groups == 1: + scales = ChannelQuantScaleParameter(output_dim=0, + **weight_scale_args) + else: + scales = GroupQuantScaleParameter(output_dim=0, + input_dim=1, + **weight_scale_args) + + if self.quant_config.zeros_mode == "quantized": + zeros = PackedvLLMParameter( + data=torch.empty( + input_groups, + output_size_per_partition // self.quant_config.pack_factor, + device="cuda", + dtype=self.quant_config.storage_torch_dtype, + requires_grad=False, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader, + ) + + else: + zeros = BasevLLMParameter( + torch.empty(output_size_per_partition, + input_groups, + device="cuda", + dtype=params_dtype), + weight_loader=weight_loader, + ) + # Set attributes to indicate how scales and zeros are applied. + set_weight_attrs(zeros, { + "input_dim": None if input_groups == 1 else 1, + "output_dim": 0, + }) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("scales", scales) + layer.register_parameter("zeros", zeros) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + if self.quant_config.quant_method == "gptq": + return self.create_weights_gptq(layer, input_size_per_partition, + output_partition_sizes, input_size, + output_size, params_dtype, + **extra_weight_attrs) + else: + raise ValueError( + f"Unsupported quant_method {self.quant_config.quant_method}") + + def _configure_bitblas_matmul( + self, + infeatures, + outfeatures, + params_dtype, + enable_tuning, + bias, + layout, + bits, + out_dtype="float16", + ): + from bitblas import MatmulConfig + bitblas_dtype = self.BITBLAS_DTYPES[params_dtype] + + with_scaling = False + with_zeros = False + group_size = self.quant_config.group_size + zeros_mode = self.quant_config.zeros_mode + if self.quant_config.quant_method == "gptq": + with_scaling = True + with_zeros = True + W_dtype = f"uint{bits}" + if self.quant_config.is_sym: + with_zeros = False + W_dtype = f"int{bits}" + else: + raise ValueError( + f"Unsupported quant_method {self.quant_config.quant_method}") + + matmul_config = MatmulConfig( + N=outfeatures, + K=infeatures, + A_dtype=bitblas_dtype, + W_dtype=W_dtype, + out_dtype=out_dtype, + accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype, + storage_dtype=self.quant_config.STORAGE_DTYPE, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + with_bias=bias, + layout=layout, + zeros_mode=zeros_mode, + ) + self.bitblas_matmul = self._get_or_create_bitblas_operator( + matmul_config, enable_tuning) + + def _get_or_create_bitblas_operator(self, config, enable_tuning): + from bitblas import Matmul, auto_detect_nvidia_target + from bitblas.cache import get_database_path, global_operator_cache + BITBLAS_DATABASE_PATH = get_database_path() + BITBLAS_TARGET = auto_detect_nvidia_target() + if global_operator_cache.size() == 0: + global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, + BITBLAS_TARGET) + + bitblas_matmul = global_operator_cache.get(config) + if bitblas_matmul is None: + bitblas_matmul = Matmul(config, + target=BITBLAS_TARGET, + enable_tuning=False) + if enable_tuning: + TUNING_MESSAGE = (f"BitBLAS Operator {config} is tuning ...") + logger.info(TUNING_MESSAGE) + bitblas_matmul.hardware_aware_finetune(topk=20) + global_operator_cache.add(config, bitblas_matmul) + global_operator_cache.save_into_database( + BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + TUNED_MESSAGE = ( + f"BitBLAS Operator {config} tuned and saved to database.") + logger.info(TUNED_MESSAGE) + else: + _message = f"BitBLAS Operator {config} created." + logger.info(_message) + else: + _message = ( + f"BitBLAS Operator {config} found in global_operator_cache.") + logger.info(_message) + return bitblas_matmul + + def apply_gptq( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.qweight + scales = layer.scales + qzeros = layer.zeros + + x_2d = x.view(-1, x.shape[-1]) + + if self.quant_config.is_sym: + output_2d = self.bitblas_matmul(x_2d, qweight, scales) + else: + output_2d = self.bitblas_matmul(x_2d, qweight, scales, qzeros) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output + + def apply( + self, + *args: Any, + **kwargs: Any, + ) -> torch.Tensor: + if self.quant_config.quant_method == "gptq": + return self.apply_gptq(*args, **kwargs) + else: + raise ValueError( + f"Unsupported quant_method {self.quant_config.quant_method}") diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py new file mode 100644 index 0000000..1ed3ef8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -0,0 +1,396 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Optional + +import torch + +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod, + set_weight_attrs) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.utils import direct_register_custom_op + + +class BitsAndBytesConfig(QuantizationConfig): + """Config class for BitsAndBytes Quantization. + + Reference: https://arxiv.org/abs/2305.14314 + """ + + def __init__( + self, + load_in_8bit: bool = False, + load_in_4bit: bool = True, + bnb_4bit_compute_dtype: str = "float32", + bnb_4bit_quant_storage: str = "uint8", + bnb_4bit_quant_type: str = "fp4", + bnb_4bit_use_double_quant: bool = False, + llm_int8_enable_fp32_cpu_offload: bool = False, + llm_int8_has_fp16_weight: bool = False, + llm_int8_skip_modules: Optional[list[str]] = None, + llm_int8_threshold: float = 6.0, + ) -> None: + super().__init__() + self.load_in_8bit = load_in_8bit + self.load_in_4bit = load_in_4bit + self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype + self.bnb_4bit_quant_storage = bnb_4bit_quant_storage + self.bnb_4bit_quant_type = bnb_4bit_quant_type + self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant + self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload + self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight + self.llm_int8_skip_modules = llm_int8_skip_modules or [] + self.llm_int8_threshold = llm_int8_threshold + + if self.bnb_4bit_quant_storage not in ["uint8"]: + raise ValueError("Unsupported bnb_4bit_quant_storage: " + f"{self.bnb_4bit_quant_storage}") + + def __repr__(self) -> str: + return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, " + f"load_in_4bit={self.load_in_4bit}, " + f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, " + f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, " + f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, " + f"llm_int8_skip_modules={self.llm_int8_skip_modules})") + + @classmethod + def get_name(self) -> QuantizationMethods: + return "bitsandbytes" + + @classmethod + def get_supported_act_dtypes(self) -> list[torch.dtype]: + return [torch.float32, torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @staticmethod + def get_config_filenames() -> list[str]: + return [] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "BitsAndBytesConfig": + + def get_safe_value(config, keys, default_value=None): + try: + value = cls.get_from_keys(config, keys) + return value if value is not None else default_value + except ValueError: + return default_value + + load_in_8bit = get_safe_value(config, ["load_in_8bit"], + default_value=False) + load_in_4bit = get_safe_value(config, ["load_in_4bit"], + default_value=True) + bnb_4bit_compute_dtype = get_safe_value(config, + ["bnb_4bit_compute_dtype"], + default_value="float32") + bnb_4bit_quant_storage = get_safe_value(config, + ["bnb_4bit_quant_storage"], + default_value="uint8") + bnb_4bit_quant_type = get_safe_value(config, ["bnb_4bit_quant_type"], + default_value="fp4") + bnb_4bit_use_double_quant = get_safe_value( + config, ["bnb_4bit_use_double_quant"], default_value=False) + llm_int8_enable_fp32_cpu_offload = get_safe_value( + config, ["llm_int8_enable_fp32_cpu_offload"], default_value=False) + llm_int8_has_fp16_weight = get_safe_value(config, + ["llm_int8_has_fp16_weight"], + default_value=False) + llm_int8_skip_modules = get_safe_value(config, + ["llm_int8_skip_modules"], + default_value=[]) + llm_int8_threshold = get_safe_value(config, ["llm_int8_threshold"], + default_value=6.0) + + return cls( + load_in_8bit=load_in_8bit, + load_in_4bit=load_in_4bit, + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + bnb_4bit_quant_storage=bnb_4bit_quant_storage, + bnb_4bit_quant_type=bnb_4bit_quant_type, + bnb_4bit_use_double_quant=bnb_4bit_use_double_quant, + llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload, + llm_int8_has_fp16_weight=llm_int8_has_fp16_weight, + llm_int8_skip_modules=llm_int8_skip_modules, + llm_int8_threshold=llm_int8_threshold) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["LinearMethodBase"]: + if isinstance(layer, LinearBase): + if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules): + return UnquantizedLinearMethod() + return BitsAndBytesLinearMethod(self) + return None + + +def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]): + # Split the prefix into its dot-separated components + components = prefix.split('.') + + # Check if any of the skip modules exactly matches any component + substr_check = any(module_name in components + for module_name in llm_int8_skip_modules) + + # Allow certain layers to not be quantized + set_components = set(".".join(components[:i + 1]) + for i in range(len(components))) + set_llm_int8_skip_modules = set(llm_int8_skip_modules) + prefix_check = len(set_llm_int8_skip_modules & set_components) != 0 + + return substr_check or prefix_check + + +class BitsAndBytesLinearMethod(LinearMethodBase): + """Linear method for BitsAndBytes. + + Args: + quant_config: The BitsAndBytes quantization config. + """ + + def __init__(self, quant_config: BitsAndBytesConfig): + try: + import bitsandbytes + if bitsandbytes.__version__ < "0.46.1": + raise ImportError("bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.46.1.") + except ImportError as err: + raise ImportError("Please install bitsandbytes>=0.46.1 via " + "`pip install bitsandbytes>=0.46.1` to use " + "bitsandbytes quantizer.") from err + + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + from bitsandbytes.nn import Int8Params + + def calculate_quant_ratio(dtype): + if dtype.is_floating_point: + return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits + else: + return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits + + def create_qweight_for_8bit(): + qweight = Int8Params( + data=torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.int8), + has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight, + requires_grad=False) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 0, + "pack_factor": 1, + "use_bitsandbytes_8bit": True, + "generation": 0 + }) + return qweight + + def create_qweight_for_4bit(): + quant_ratio = calculate_quant_ratio(params_dtype) + + total_size = input_size_per_partition * sum(output_partition_sizes) + if total_size % quant_ratio != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape.") + + qweight = torch.nn.Parameter(torch.empty(total_size // quant_ratio, + 1, + dtype=torch.uint8), + requires_grad=False) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 0, + "pack_factor": quant_ratio, + "use_bitsandbytes_4bit": True + }) + return qweight + + if self.quant_config.load_in_8bit: + qweight = create_qweight_for_8bit() + else: + qweight = create_qweight_for_4bit() + # Enable parameters to have the same name as in the BNB + # checkpoint format. + layer.register_parameter("weight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + if self.quant_config.load_in_8bit: + return self._apply_8bit_weight(layer, x, bias) + else: + return self._apply_4bit_weight(layer, x, bias) + + def _apply_8bit_weight( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + # only load the bitsandbytes module when needed + from bitsandbytes import MatmulLtState, matmul + + original_type = x.dtype + original_shape = x.shape + reshape_after_matmul = False + if x.ndim > 2: + x = x.reshape(-1, x.size(-1)) + reshape_after_matmul = True + bf_x = x.to(torch.bfloat16) + + qweight = layer.weight + offsets = qweight.bnb_shard_offsets + quant_states = qweight.bnb_quant_state + matmul_states = qweight.matmul_state + generation = qweight.generation + + out_dim_0 = x.shape[0] + out_dim_1 = sum( + [quant_state[1].shape[0] for quant_state in quant_states.items()]) + out = torch.empty(out_dim_0, + out_dim_1, + dtype=torch.float16, + device=x.device) + + current_index = 0 + for i in range(len(quant_states)): + output_size = quant_states[i].shape[0] + + # in profile_run or the first generation of inference, + # create new matmul_states + if generation == 0 or generation == 1: + matmul_states[i] = MatmulLtState() + matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]] + matmul_states[i].SCB = quant_states[i].to(x.device) + matmul_states[i].threshold = ( + self.quant_config.llm_int8_threshold) + matmul_states[i].has_fp16_weights = ( + self.quant_config.llm_int8_has_fp16_weight) + matmul_states[i].is_training = False + if matmul_states[i].threshold > 0.0 and not matmul_states[ + i].has_fp16_weights: + matmul_states[i].use_pool = True + + new_x = bf_x.unsqueeze(0) + + out[:, current_index:current_index + output_size] = matmul( + new_x, + qweight[offsets[i]:offsets[i + 1]], + state=matmul_states[i]) + + current_index += output_size + + # only update the matmul_states if it is not profile_run + if (generation > 0 + and not self.quant_config.llm_int8_has_fp16_weight + and matmul_states[i].CB is not None + and matmul_states[i].CxB is not None): + del matmul_states[i].CB + qweight[offsets[i]:offsets[i + 1]] = matmul_states[i].CxB + + out = out.to(original_type) + + if reshape_after_matmul: + out = out.view(*original_shape[:-1], out.size(-1)) + + if bias is not None: + out += bias + + qweight.generation += 1 + + return out + + def _apply_4bit_weight( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + original_type = x.dtype + original_shape = x.shape + reshape_after_matmul = False + if x.ndim > 2: + x = x.reshape(-1, x.size(-1)) + reshape_after_matmul = True + bf_x = x.to(torch.bfloat16) + + qweight = layer.weight + quant_states = qweight.bnb_quant_state + offsets = qweight.bnb_shard_offsets + + out_dim_0 = x.shape[0] + out_dim_1 = sum( + [quant_state[1].shape[0] for quant_state in quant_states.items()]) + out = torch.empty(out_dim_0, + out_dim_1, + dtype=torch.bfloat16, + device=x.device) + apply_bnb_4bit(bf_x, qweight, offsets, out) + out = out.to(original_type) + + if reshape_after_matmul: + out = out.view(*original_shape[:-1], out.size(-1)) + + if bias is not None: + out += bias + + return out + + +def _apply_bnb_4bit( + x: torch.Tensor, + weight: torch.Tensor, + offsets: torch.Tensor, + out: torch.Tensor, +) -> None: + # only load the bitsandbytes module when needed + from bitsandbytes import matmul_4bit + quant_states = weight.bnb_quant_state + current_index = 0 + for i in range(len(quant_states)): + output_size = quant_states[i].shape[0] + # It is more efficient to use out kwarg like + # matmul_4bit(..., out = ...). Infeasible now due to the bug + # https://github.com/TimDettmers/bitsandbytes/issues/1235. + # Need to change after the bug is fixed. + out[:, current_index:current_index + output_size] = matmul_4bit( + x, weight[offsets[i]:offsets[i + 1]].t(), quant_states[i]) + current_index += output_size + + +def _apply_bnb_4bit_fake( + x: torch.Tensor, + weight: torch.Tensor, + offsets: torch.Tensor, + out: torch.Tensor, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="apply_bnb_4bit", + op_func=_apply_bnb_4bit, + mutates_args=["out"], + fake_impl=_apply_bnb_4bit_fake, + ) + apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit + +except AttributeError as error: + raise error diff --git a/vllm/model_executor/layers/quantization/blockwise_int8.py b/vllm/model_executor/layers/quantization/blockwise_int8.py new file mode 100644 index 0000000..8cb3ab4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/blockwise_int8.py @@ -0,0 +1,518 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from https://github.com/sgl-project/sglang/pull/3730 + +import logging +from typing import Any, Callable, Dict, List, Optional + +import torch +from torch.nn import Module +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) + +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, + FusedMoeWeightScaleSupported) +from vllm.model_executor.parameter import (BlockQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) + +from lmslim.layers.gemm.int8_utils import ( + apply_w8a8_block_int8_linear) + +from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import W8a8GetCacheJSON + +import os +from vllm import _custom_ops as ops + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = logging.getLogger(__name__) + + +class BlockInt8Config(QuantizationConfig): + """Config class for INT8.""" + + def __init__( + self, + is_checkpoint_int8_serialized: bool = False, + activation_scheme: str = "dynamic", + ignored_layers: Optional[List[str]] = None, + weight_block_size: Optional[List[int]] = None, + ) -> None: + self.is_checkpoint_int8_serialized = is_checkpoint_int8_serialized + if is_checkpoint_int8_serialized: + logger.warning( + "Detected int8 checkpoint. Please note that the " + "format is experimental and subject to change." + ) + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError("Unsupported activation scheme" + f" {activation_scheme}") + self.activation_scheme = activation_scheme + self.ignored_layers = ignored_layers or [] + if weight_block_size is not None: + if not is_checkpoint_int8_serialized: + raise ValueError( + f"The block-wise quantization only supports " + "int8-serialized checkpoint for now." + ) + if len(weight_block_size) != 2: + raise ValueError( + f"The quantization block size of weight must have 2 " + "dimensions, but got {len(weight_block_size)} dimensions." + ) + if activation_scheme != "dynamic": + raise ValueError( + f"The block-wise quantization only supports dynamic " + "activation scheme for now, but got " + "{activation_scheme} activation scheme." + ) + self.weight_block_size = weight_block_size + + @classmethod + def get_name(cls) -> str: + return "blockwise_int8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "BlockInt8Config": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_int8_serialized = "int8" in quant_method + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + weight_block_size = cls.get_from_keys_or(config, + ["weight_block_size"], None) + return cls( + is_checkpoint_int8_serialized=is_checkpoint_int8_serialized, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers, + weight_block_size=weight_block_size, + ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, self.ignored_layers): + return UnquantizedLinearMethod() + return BlockInt8LinearMethod(self) + elif isinstance(layer, FusedMoE): + return BlockInt8MoEMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class BlockInt8LinearMethod(LinearMethodBase): + """Linear method for INT8. + Supports loading INT8 checkpoints with static weight scale and + dynamic activation scale. + Limitations: + Only support block-wise int8 quantization and int8 checkpoint + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: BlockInt8Config): + self.quant_config = quant_config + self.tritonsingleton= W8a8GetCacheJSON() + self.block_size=self.quant_config.weight_block_size + + assert self.quant_config.weight_block_size is not None + assert self.quant_config.is_checkpoint_int8_serialized + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: Optional[List[int]], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # assert output_partition_sizes is not None, ( + # "output_partition_sizes must be provided for quantization") + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + tp_size = get_tensor_model_parallel_world_size() + + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # Required by row parallel + if tp_size > 1 and input_size // input_size_per_partition == tp_size: + if input_size_per_partition % block_k != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"weight quantization block_k = {block_k}." + ) + # Required by collum parallel or enabling merged weights + if (tp_size > 1 and output_size // output_size_per_partition == tp_size) or len( + output_partition_sizes + ) > 1: + for output_partition_size in output_partition_sizes: + if output_partition_size % block_n != 0: + raise ValueError( + f"Weight output_partition_size = " + f"{output_partition_size} is not divisible by " + f"weight quantization block_n = {block_n}." + ) + + layer.logical_widths = output_partition_sizes + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # WEIGHT + weight_dtype = ( + torch.int8 + if self.quant_config.is_checkpoint_int8_serialized + else params_dtype + ) + + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, input_size_per_partition, dtype=weight_dtype + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + + scale = BlockQuantScaleParameter( + data=torch.empty( + (output_size_per_partition + block_n - 1) // block_n, + (input_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale_inv", scale) + + # INPUT ACTIVATION SCALE + assert self.quant_config.activation_scheme == "dynamic" + layer.register_parameter("input_scale", None) + + + def process_weights_after_loading(self, layer: Module) -> None: + # Block quant doesn't need to process weights after loading + # Use torch Parameter to avoid cuda graph capturing issue + n=layer.weight.shape[0] + k=layer.weight.shape[1] + + if [n,k] not in self.tritonsingleton.weight_shapes: + self.tritonsingleton.weight_shapes.append([n,k]) + json_file=self.tritonsingleton.get_blockint8json_name(n,k,self.block_size[0],self.block_size[1]) + configs_dict=self.tritonsingleton.get_blockint8_triton_cache(json_file,n,k,self.block_size[0],self.block_size[1]) + + if configs_dict: + self.tritonsingleton.triton_json_dict.update(configs_dict) + + for key, value in configs_dict.items(): + m=int(key.split('_')[0]) + + ops.triton_blockint8_gemm_helper(m=m,n=n,k=k,block_size=self.block_size,use_bias=False,out_dtype=torch.bfloat16,device=layer.weight.device,best_config=value) + + layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) + layer.weight_scale_inv = torch.nn.Parameter( + layer.weight_scale_inv.data, requires_grad=False + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + M=x.shape[0] + K=x.shape[1] + N=layer.weight.shape[0] + + #Get the best config options + if len(self.tritonsingleton.triton_json_dict)==0: + config=None + + elif f"1_{N}_{K}_block[{self.block_size[0]},{self.block_size[1]}]" in self.tritonsingleton.triton_json_dict: + if M<=16: + m_=M + elif M<=64: + m_= (M + 3) & -4 #取值到最近的4的倍数 + elif M<=160: + m_=(M + 7) & -8 + + elif M<200: #256 + m_=160 + elif M<480: #512 + m_=256 + elif M<960: #1024 + m_=512 + elif M<2048: + m_=1024 + elif M<4096: + m_=2048 + elif M<6000: + m_=4096 + else: + m_=8192 + + config=self.tritonsingleton.triton_json_dict[f"{m_}_{N}_{K}_block[{self.block_size[0]},{self.block_size[1]}]"] + + else: + config=None + + return apply_w8a8_block_int8_linear( + input=x, + weight=layer.weight, + block_size=self.quant_config.weight_block_size, + weight_scale=layer.weight_scale_inv, + input_scale=None, + bias=bias, + config=config + ) + +class BlockInt8MoEMethod: + """MoE method for INT8. + Supports loading INT8 checkpoints with static weight scale and + dynamic activation scale. + + Limitations: + Only support block-wise int8 quantization and int8 checkpoint + + Args: + quant_config: The quantization config. + """ + + def __new__(cls, *args, **kwargs): + from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase + + if not hasattr(cls, "_initialized"): + original_init = cls.__init__ + new_cls = type( + cls.__name__, + (FusedMoEMethodBase,), + { + "__init__": original_init, + **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, + }, + ) + obj = super(new_cls, new_cls).__new__(new_cls) + obj.__init__(*args, **kwargs) + return obj + return super().__new__(cls) + + def __init__(self, quant_config): + self.quant_config = quant_config + assert self.quant_config.weight_block_size is not None + assert self.quant_config.is_checkpoint_int8_serialized + self.tritonsingleton= W8a8GetCacheJSON() + + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from vllm.model_executor.layers.fused_moe import FusedMoeWeightScaleSupported + + if self.quant_config.is_checkpoint_int8_serialized: + params_dtype = torch.int8 + tp_size = get_tensor_model_parallel_world_size() + + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n. + # Required by collum parallel or enabling merged weights + if intermediate_size % block_n != 0: + raise ValueError( + f"The output_size of gate's and up's weight = " + f"{intermediate_size} is not divisible by " + f"weight quantization block_n = {block_n}." + ) + if tp_size > 1: + # Required by row parallel + if intermediate_size % block_k != 0: + raise ValueError( + f"The input_size of down's weight = " + f"{intermediate_size} is not divisible by " + f"weight quantization block_k = {block_k}." + ) + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, hidden_size, intermediate_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * ((intermediate_size + block_n - 1) // block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + (hidden_size + block_n - 1) // block_n, + (intermediate_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + assert self.quant_config.activation_scheme == "dynamic" + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + # Block quant doesn't need to process weights after loading + # warmup and get moe block-int8 config + E=layer.w13_weight.shape[0] + N1=layer.w13_weight.shape[1] + N2=layer.w2_weight.shape[1] + K=N1//2 + if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes: + self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K]) + + TOPK= self.tritonsingleton.topk + block_size=self.quant_config.weight_block_size + + json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK,block_size,) + configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK) + + #warmup + if configs_dict: + self.tritonsingleton.triton_moejson_dict.update(configs_dict) + + #print("*************self.tritonsingleton:",self.tritonsingleton) + #生成模型配置文件 + self.tritonsingleton.gen_model_json(block_size) + + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + use_nn_moe: Optional[bool] = False, + routed_scaling_factor: Optional[float] = None, + use_fused_gate: Optional[bool] = False, + **_ + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `MoeBlockInt8Method` yet.") + # Expert selection + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + routed_scaling_factor=routed_scaling_factor, + use_fused_gate=use_fused_gate + ) + + # Expert fusion with INT8 quantization + + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_int8_w8a8=True, + activation=activation, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + w1_scale=(layer.w13_weight_scale_inv), + w2_scale=(layer.w2_weight_scale_inv), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=self.quant_config.weight_block_size, + use_nn_moe=use_nn_moe + ) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py new file mode 100644 index 0000000..20ba078 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -0,0 +1,724 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from contextlib import suppress +from typing import TYPE_CHECKING, Any, Literal, Optional, cast + +import torch +from compressed_tensors.config import (CompressionFormat, + SparsityCompressionConfig, + SparsityStructure) +from compressed_tensors.quantization import (QuantizationArgs, + QuantizationStrategy, + QuantizationType) +from pydantic import BaseModel + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) + +from vllm.model_executor.layers.vocab_parallel_embedding import UnquantizedEmbeddingMethod +from vllm.model_executor.layers.quantization import QuantizationMethods + +from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501 + CompressedTensorsMoEMethod) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24, + CompressedTensorsScheme, CompressedTensorsW4A4Fp4, + CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, + CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, + CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + find_matched_target, is_activation_quantization_format, + should_ignore_layer) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 + cutlass_fp4_supported) +from vllm.platforms import current_platform +from vllm.utils import W8a8GetCacheJSON + +import os +from vllm import _custom_ops as ops + +if TYPE_CHECKING: + from vllm.model_executor.models.utils import WeightsMapper + +logger = init_logger(__name__) + +__all__ = ["CompressedTensorsLinearMethod"] + +SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config" +QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]] + + +class CompressedTensorsConfig(QuantizationConfig): + + def __init__( + self, + target_scheme_map: dict[str, Any], + ignore: list[str], + quant_format: str, + sparsity_scheme_map: dict[str, SparsityCompressionConfig], + sparsity_ignore_list: list[str], + kv_cache_scheme: Optional[dict[str, Any]] = None, + config: Optional[dict[str, Any]] = None, + ): + super().__init__() + self.ignore = ignore + self.quant_format = quant_format + # Map from [target -> scheme] + self.target_scheme_map = target_scheme_map + self.kv_cache_scheme = kv_cache_scheme + self.sparsity_scheme_map = sparsity_scheme_map + self.sparsity_ignore_list = sparsity_ignore_list + self.config = config + + def get_linear_method(self) -> "CompressedTensorsLinearMethod": + return CompressedTensorsLinearMethod(self) + + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + def get_name(self) -> QuantizationMethods: + return "compressed-tensors" + + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + self.target_scheme_map = hf_to_vllm_mapper.apply_dict( + self.target_scheme_map) + self.ignore = hf_to_vllm_mapper.apply_list(self.ignore) + self.sparsity_scheme_map = hf_to_vllm_mapper.apply_dict( + self.sparsity_scheme_map) + self.sparsity_ignore_list = hf_to_vllm_mapper.apply_list( + self.sparsity_ignore_list) + if self.kv_cache_scheme is not None: + self.kv_cache_scheme = hf_to_vllm_mapper.apply_dict( + self.kv_cache_scheme) + + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + # Check if the layer is skipped for quantization. + # TODO (@robertgshaw2): support module names + if should_ignore_layer(prefix, + ignore=self.ignore, + fused_mapping=self.packed_modules_mapping): + return UnquantizedEmbeddingMethod()#UnquantizedLinearMethod() + if isinstance(layer, LinearBase): + scheme = self.get_scheme(layer=layer, layer_name=prefix) + if scheme is None: + return UnquantizedEmbeddingMethod()#UnquantizedLinearMethod() + layer.scheme = scheme + return CompressedTensorsLinearMethod(self) + if isinstance(layer, Attention): + return CompressedTensorsKVCacheMethod(self) + if isinstance(layer, FusedMoE): + return CompressedTensorsMoEMethod.get_moe_method(self, layer) + return None + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig": + ignore: list[str] = cast(list[str], config.get("ignore", [])) + quant_format = cast(str, config.get("format")) + target_scheme_map = cls._quantization_scheme_map_from_config( + config=config) + sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config( + config=config) + + return cls( + target_scheme_map=target_scheme_map, + ignore=ignore, + quant_format=quant_format, + sparsity_scheme_map=sparsity_scheme_map, + sparsity_ignore_list=sparsity_ignore_list, + config=config, + ) + + @classmethod + def _parse_sparsity_config( + cls, config: dict[str, Any] + ) -> tuple[dict[str, SparsityCompressionConfig], list[str]]: + """ + :param config: The `quantization_config` dictionary from config.json + :return: A tuple with two elements + 1. A dictionary mapping target layer names to their corresponding + sparsity_config + 2. A list of layer names to ignore for sparsity + """ + if not (sparsity_config := config.get(SPARSITY_CONFIG_NAME)): + return dict(), [] + + sparsity_config = SparsityCompressionConfig.model_validate( + sparsity_config) + sparse_scheme_map: dict[str, SparsityCompressionConfig] = { + target: sparsity_config + for target in sparsity_config.targets or list() + } + sparsity_ignore_list = sparsity_config.ignore or list() + return sparse_scheme_map, sparsity_ignore_list + + @classmethod + def _quantization_scheme_map_from_config( + cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE: + """ + :param config: The `quantization_config` dictionary from config.json + :return: A dictionary mapping target layer names to their corresponding + quantization_args for weights and input activations + """ + target_scheme_map: dict[str, Any] = dict() + quant_format = cast(str, config.get("format")) + + # The quant_config has multiple config_groups, each containing + # an input_activations key with details about how the activations are + # quantized, a weights key indicating how the weights are quantized, + # and a list of targets under the `targets` key, dictating which + # layers are impacted by the quantization details. The quantization + # details follow the structure defined by the QuantizationArgs + # pydantic model, which is used to verify the structure of the + # quant_config and also store the details for later use. + + config_groups = config.get("config_groups", dict()) + for _, quant_config in config_groups.items(): + targets = quant_config.get("targets") + for target in targets: + target_scheme_map[target] = {} + target_scheme_map[target][ + "weights"] = QuantizationArgs.model_validate( + quant_config.get("weights")) + + target_scheme_map[target]["input_activations"] = None + if is_activation_quantization_format(quant_format): + input_activations = quant_config.get("input_activations") + # The only case where we have activation quant supported + # but no input_activations provided in the config + # should be w8a16fp8 w8a16fp8 can also run for cases where + # there is an input_quant but it is ignored + if not input_activations: + assert target_scheme_map[target][ + "weights"].type == QuantizationType.FLOAT + else: + target_scheme_map[target][ + "input_activations"] = QuantizationArgs.model_validate( # noqa: E501 + quant_config.get("input_activations")) + return target_scheme_map + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + def _check_scheme_supported(self, + min_capability: int, + error: bool = True, + match_exact: bool = False) -> bool: + capability_tuple = current_platform.get_device_capability() + + if capability_tuple is not None: + capability = capability_tuple.to_int() + if match_exact: + supported = capability == min_capability + if error and not supported: + raise RuntimeError( + "Quantization scheme is not supported for ", + "the current GPU. Required capability: ", + f"{min_capability}. Current capability: {capability}.") + else: + supported = capability >= min_capability + if error and not supported: + raise RuntimeError( + "Quantization scheme is not supported for ", + f"the current GPU. Min capability: {min_capability}. ", + f"Current capability: {capability}.") + return supported + else: + return False + + def _is_fp4a4_nvfp4(self, weight_quant: BaseModel, input_quant: BaseModel): + + if weight_quant is None or input_quant is None: + return False + + is_tensor_group_quant = (weight_quant.strategy + == QuantizationStrategy.TENSOR_GROUP.value + and input_quant.strategy + == QuantizationStrategy.TENSOR_GROUP.value) + is_symmetric = weight_quant.symmetric and input_quant.symmetric + + is_group_size_16 = (weight_quant.group_size == 16 + and input_quant.group_size == 16) + is_float_type = (weight_quant.type == QuantizationType.FLOAT + and input_quant.type == QuantizationType.FLOAT.value) + is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4 + + return (is_tensor_group_quant and is_float_type and is_4_bits + and is_group_size_16 and is_symmetric) + + def _is_fp4a16_nvfp4(self, weight_quant: BaseModel, + input_quant: BaseModel): + + is_weight_only = weight_quant is not None and input_quant is None + is_tensor_group_quant = ( + weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value) + is_symmetric = weight_quant.symmetric + + is_group_size_16 = weight_quant.group_size == 16 + is_float_type = weight_quant.type == QuantizationType.FLOAT + is_4_bits = weight_quant.num_bits == 4 + + return (is_weight_only and is_tensor_group_quant and is_float_type + and is_4_bits and is_group_size_16 and is_symmetric) + + def _is_static_tensor_w8a8(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 + weight_strategy = ( + weight_quant.strategy == QuantizationStrategy.TENSOR.value + or weight_quant.strategy == QuantizationStrategy.CHANNEL.value) + is_tensor = (weight_strategy and input_quant.strategy + == QuantizationStrategy.TENSOR.value) + is_static = not weight_quant.dynamic and not input_quant.dynamic + + # Both symmetric and asymmetric input quantization supported. + # Only symmetric weight quantization supported. + return is_8_bits and is_tensor and weight_quant.symmetric and is_static + + def _is_dynamic_token_w8a8(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 + weight_strategy = ( + weight_quant.strategy == QuantizationStrategy.TENSOR.value + or weight_quant.strategy == QuantizationStrategy.CHANNEL.value) + is_token = (weight_strategy and input_quant.strategy + == QuantizationStrategy.TOKEN.value) + is_dynamic = not weight_quant.dynamic and input_quant.dynamic + + # Both symmetric and asymmetric input quantization supported. + # Only symmetric weight quantization supported. + return is_8_bits and is_token and weight_quant.symmetric and is_dynamic + + def _is_fp8_w8a8(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + # Confirm weights and activations quantized. + if weight_quant is None or input_quant is None: + return False + + # Confirm weight scheme is supported. + is_floating_point = (weight_quant.type == QuantizationType.FLOAT + and input_quant.type == QuantizationType.FLOAT) + is_symmetric_weight = weight_quant.symmetric + is_static_weight = not weight_quant.dynamic + is_per_tensor_or_channel_weight = (weight_quant.strategy in [ + QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL + ]) + if not (is_floating_point and is_symmetric_weight and is_static_weight + and is_per_tensor_or_channel_weight): + return False + + # Dynamic quantization is always supported if weights supported. + if input_quant.dynamic: + return True + + # Confirm activation scheme is supported. + is_symmetric_activation = input_quant.symmetric + is_per_tensor_activation = ( + input_quant.strategy == QuantizationStrategy.TENSOR) + return is_symmetric_activation and is_per_tensor_activation + + def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + return (self._check_scheme_supported(90, error=False, match_exact=True) + and self._is_fp8_w8a8(weight_quant, input_quant)) + + def _is_fp8_w8a16(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + # Confirm weights quantized. + if weight_quant is None: + return False + + # Confirm we have floating points. + if weight_quant.type != QuantizationType.FLOAT: + return False + + # Confirm weight scheme is supported. + is_symmetric_weight = weight_quant.symmetric + is_static_weight = not weight_quant.dynamic + is_per_tensor_or_channel_weight = (weight_quant.strategy in [ + QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL + ]) + if not (is_symmetric_weight and is_static_weight # noqa: SIM103 + and is_per_tensor_or_channel_weight): + return False + + # All conditions satisfied. + return True + + def _is_wNa16_group_channel(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + input_quant_none = input_quant is None + is_channel_group = ( + weight_quant.strategy == QuantizationStrategy.CHANNEL.value + or weight_quant.strategy == QuantizationStrategy.GROUP.value) + is_static = not weight_quant.dynamic + + return (is_channel_group and input_quant_none and is_static) + + def _get_scheme_from_parts( + self, weight_quant: BaseModel, + input_quant: BaseModel) -> "CompressedTensorsScheme": + + # Detect If Mixed Precision + if self._is_fp4a16_nvfp4(weight_quant, input_quant): + return CompressedTensorsW4A16Fp4() + + if self._is_wNa16_group_channel(weight_quant, input_quant): + if (self.quant_format == CompressionFormat.marlin_24.value + and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): + assert weight_quant.symmetric + return CompressedTensorsW4A16Sparse24( + strategy=weight_quant.strategy, + num_bits=weight_quant.num_bits, + group_size=weight_quant.group_size) + if (self.quant_format == CompressionFormat.pack_quantized.value + and weight_quant.num_bits in WNA16_SUPPORTED_BITS): + return CompressedTensorsWNA16( + num_bits=weight_quant.num_bits, + strategy=weight_quant.strategy, + symmetric=weight_quant.symmetric, + group_size=weight_quant.group_size, + actorder=weight_quant.actorder) + + if is_activation_quantization_format(self.quant_format): + if self._is_fp4a4_nvfp4(weight_quant, input_quant): + if cutlass_fp4_supported( + ) or envs.VLLM_USE_NVFP4_CT_EMULATIONS: + return CompressedTensorsW4A4Fp4() + else: + logger.warning_once( + "Current platform does not support cutlass NVFP4." + " Running CompressedTensorsW4A16Fp4.") + return CompressedTensorsW4A16Fp4( + has_input_global_scale=True) + + if self._is_fp8_w8a8(weight_quant, input_quant): + is_fp8_w8a8_supported = self._check_scheme_supported( + CompressedTensorsW8A8Fp8.get_min_capability(), error=False) + if is_fp8_w8a8_supported: + return CompressedTensorsW8A8Fp8( + strategy=weight_quant.strategy, + is_static_input_scheme=(input_quant + and not input_quant.dynamic)) + else: + # note: input_quant will be present for converted models; + # will be ignored during inference post loading + return CompressedTensorsW8A16Fp8( + strategy=weight_quant.strategy, + is_static_input_scheme=not input_quant.dynamic) + + # note: input_quant can be None + if self._is_fp8_w8a16(weight_quant, input_quant): + is_static_input_scheme = (input_quant + and not input_quant.dynamic) + return CompressedTensorsW8A16Fp8( + strategy=weight_quant.strategy, + is_static_input_scheme=is_static_input_scheme) + + if self._is_static_tensor_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8Int8( + strategy=weight_quant.strategy, + is_static_input_scheme=True, + input_symmetric=input_quant.symmetric) + + if self._is_dynamic_token_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8Int8( + strategy=weight_quant.strategy, + is_static_input_scheme=False, + input_symmetric=input_quant.symmetric) + + raise NotImplementedError( + "No compressed-tensors compatible scheme was found.") + + def get_scheme(self, + layer: torch.nn.Module, + layer_name: Optional[str] = None + ) -> Optional["CompressedTensorsScheme"]: + """ + compressed-tensors supports non uniform in the following way: + + targets of config_groups: There can be N config_groups which each + have a quantization scheme. Each config_group has a list of targets + which can be a full layer_name, a regex for a layer_name, or + an nn.Module name. + + Detect whether a layer_name is found in any target and + use the quantization scheme corresponding to the matched target + to select the CompressedTensorsScheme used for inference. + """ + + # Find the "target" in the compressed-tensors config + # that our layer conforms to. + # TODO (@robertgshaw): add compressed-tensors as dep + # so we do not have to re-write these functions + # need to make accelerate optional in ct to do this + + # Will be empty for models with only sparsity + weight_quant = input_quant = None + if self.target_scheme_map: + matched_target = find_matched_target( + layer_name=layer_name, + module=layer, + targets=self.target_scheme_map.keys(), + fused_mapping=self.packed_modules_mapping) + + scheme_dict = self.target_scheme_map[matched_target] + weight_quant = scheme_dict.get("weights") + input_quant = scheme_dict.get("input_activations") + + # Find the sparsity scheme of the layer + # assume that fused layers inerhit first component's sparsity scheme + sparsity_targets = (self.sparsity_scheme_map.keys() - + set(self.sparsity_ignore_list)) + sparsity_scheme: Optional[SparsityCompressionConfig] = None + with suppress(ValueError): + matched_target = find_matched_target( + layer_name=layer_name, + module=layer, + targets=sparsity_targets, + fused_mapping=self.packed_modules_mapping) + sparsity_scheme = self.sparsity_scheme_map[matched_target] + + if self.supports_cutlass_24(weight_quant=weight_quant, + input_quant=input_quant, + sparsity_scheme=sparsity_scheme): + # Have a valid sparsity scheme + # Validate layer is supported by Cutlass 2:4 Kernel + model_compression_config = (None if sparsity_scheme is None + or sparsity_scheme.format == "dense" + else self.config) + + scheme = CompressedTensors24( + quantized=weight_quant is not None or input_quant is not None, + weight_quant=weight_quant, + input_quant=input_quant, + model_compression_config=model_compression_config, + ) + elif weight_quant is None: + logger.warning_once("Acceleration for non-quantized schemes is " + "not supported by Compressed Tensors. " + "Falling back to UnquantizedLinearMethod") + return None + + else: + # Find the quant_scheme + scheme = self._get_scheme_from_parts( # type: ignore + weight_quant=weight_quant, + input_quant=input_quant, + ) + + # Raise error if device does not support the scheme + # (e.g. fp8 needs ada lovelace) + self._check_scheme_supported(scheme.get_min_capability()) + logger.debug("Using scheme: %s for %s", scheme.__class__.__name__, + layer_name) + return scheme + + def get_cache_scale(self, name: str) -> Optional[str]: + """ + Check whether the param name matches the format for k/v cache scales + in compressed-tensors. If this is the case, return its equivalent + param name expected by vLLM + + :param name: param name + :return: matching param name for KV cache scale in vLLM + """ + if name.endswith(".output_scale") and ".k_proj" in name: + return name.replace(".k_proj.output_scale", ".attn.k_scale") + if name.endswith(".output_scale") and ".v_proj" in name: + return name.replace(".v_proj.output_scale", ".attn.v_scale") + # If no matches, return None + return None + + @staticmethod + def supports_cutlass_24( + weight_quant: Optional[QuantizationArgs], + input_quant: Optional[QuantizationArgs], + sparsity_scheme: Optional[SparsityCompressionConfig] = None + ) -> bool: + """ + Check if the layer is supported by the Cutlass 2:4 Kernel + Conditions: + - Overarching condition: Sparsity Structure is 2:4 + - Unquantized cases are supported + - Weight only quantization is not-supported + - Supported weight quantization strategies are TENSOR and CHANNEL + - Supported input quantization strategies are TENSOR and TOKEN + - Only 8 bit quantization is supported + + :return: True if the layer is supported by the Cutlass 2:4 Kernel + False otherwise + """ + if sparsity_scheme is None: + return False + + is_valid_sparsity_structure: bool = ( + sparsity_scheme.sparsity_structure == + SparsityStructure.TWO_FOUR.value) + + valid_compressors = { + CompressionFormat.dense.value, + CompressionFormat.sparse_24_bitmask.value + } + + is_valid_sparsity = (is_valid_sparsity_structure + and sparsity_scheme.format in valid_compressors) + + if not is_valid_sparsity: + return False + + # Unquantized cases are supported + if weight_quant is None and input_quant is None: + return True + + # Weight only quantization is not-supported + if weight_quant is not None and input_quant is None: + return False + + supported_weight_quant_strategies = [ + QuantizationStrategy.TENSOR.value, + QuantizationStrategy.CHANNEL.value + ] + + assert weight_quant is not None + assert input_quant is not None + if weight_quant.strategy not in supported_weight_quant_strategies: + return False + + supported_input_quant_strategies = [ + QuantizationStrategy.TENSOR.value, QuantizationStrategy.TOKEN.value + ] + + if input_quant.strategy not in supported_input_quant_strategies: + return False + + return weight_quant.num_bits == input_quant.num_bits == 8 + + +class CompressedTensorsLinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: CompressedTensorsConfig): + self.quantization_config = quantization_config + self.tritonsingleton= W8a8GetCacheJSON() + self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + n=layer.weight.shape[0] + k=layer.weight.shape[1] + + if self.w8a8_strategy==1: + if [n,k] not in self.tritonsingleton.weight_shapes: + self.tritonsingleton.weight_shapes.append([n,k]) + json_file=self.tritonsingleton.get_w8a8json_name(n,k) + configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k) + + if configs_dict: + self.tritonsingleton.triton_json_dict.update(configs_dict) + + for key, value in configs_dict.items(): + m=int(key.split('_')[0]) + ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value) + else: + weight_data=layer.weight.data + _weight=weight_data.T.contiguous().reshape(n,-1) + layer.weight.data=_weight + + self.tritonsingleton.gen_model_json() + layer.scheme.process_weights_after_loading(layer) + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """ + Use the CompressedTensorsScheme associated with each layer to create + the necessary parameters for the layer. See LinearMethodBase for param + details + """ + weight_loader = extra_weight_attrs.get("weight_loader") + layer.scheme.create_weights( + layer=layer, + input_size=input_size, + input_size_per_partition=input_size_per_partition, + output_partition_sizes=output_partition_sizes, + output_size=output_size, + params_dtype=params_dtype, + weight_loader=weight_loader) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None): + """ + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the + layer input. See LinearMethodBase for param details + + """ + + scheme = layer.scheme + if scheme is None: + raise ValueError("A scheme must be defined for each layer") + return scheme.apply_weights(layer, x, bias=bias) + + +class CompressedTensorsKVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from compressed-tensors + checkpoints. + """ + + def __init__(self, quant_config: CompressedTensorsConfig): + self.validate_kv_cache_scheme(quant_config.kv_cache_scheme) + super().__init__(quant_config) + + @staticmethod + def validate_kv_cache_scheme(kv_cache_scheme: Optional[dict[str, Any]]): + """ + Validator for the kv cache scheme. Useful for controlling the + kv cache quantization schemes, that are being supported in vLLM + :param kv_cache_scheme: the compressed-tensors kv cache scheme + """ + if kv_cache_scheme is None: + return + + type_ = kv_cache_scheme.get("type") + num_bits = kv_cache_scheme.get("num_bits") + + if type_ != "float" and num_bits != 8: + raise NotImplementedError( + "Currently supported kv cache quantization is " + "num_bits=8, type=float, however " + f"received num_bits={num_bits}, type={type_}") + + strategy = kv_cache_scheme.get("strategy") + if strategy != "tensor": + raise NotImplementedError( + "Only support per-tensor scaling factor " + "for compressed-tensors KV cache. " + f"Expected strategy: tensor, found strategy: {strategy}") + + is_symmetric = kv_cache_scheme.get("symmetric") + if not is_symmetric: + raise NotImplementedError( + "Only support symmetric scaling factor " + "for compressed-tensors KV cache. " + f"However found symmetric: {is_symmetric}") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py new file mode 100644 index 0000000..66ad18b --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -0,0 +1,1663 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import enum +from enum import Enum +from typing import Callable, Optional + +import torch +from compressed_tensors import CompressionFormat +from compressed_tensors.quantization import (ActivationOrdering, + QuantizationStrategy) + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, + FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, + FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa + WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP) +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_moe_marlin_supports_layer, marlin_make_workspace_new, + marlin_moe_permute_scales) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + prepare_moe_fp4_layer_for_marlin) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + prepare_moe_fp8_layer_for_marlin) +from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 + cutlass_fp4_supported) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types +from vllm.utils import W8a8GetCacheJSON + +logger = init_logger(__name__) + + +class GPTQMarlinState(Enum): + REPACK = enum.auto() + READY = enum.auto() + + +__all__ = [ + "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", + "CompressedTensorsW8A8Fp8MoECutlassMethod", + "CompressedTensorsW8A8Int8MoEMethod", + "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod", + "CompressedTensorsW4A4MoeMethod" +] + + +class CompressedTensorsMoEMethod(FusedMoEMethodBase): + + @staticmethod + def get_moe_method( + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + layer: torch.nn.Module, + ) -> "CompressedTensorsMoEMethod": + # TODO: @dsikka: refactor this to use schemes as other kernels + # are supported + check if the layer is being ignored. + weight_quant = quant_config.target_scheme_map["Linear"].get("weights") + input_quant = quant_config.target_scheme_map["Linear"].get( + "input_activations") + + if quant_config._is_wNa16_group_channel(weight_quant, input_quant): + # group_size=None means channelwise + group_size = weight_quant.group_size or -1 + # Prefer to use the MarlinMoE kernel when it is supported. + if not check_moe_marlin_supports_layer(layer, group_size): + if (weight_quant.strategy in QuantizationStrategy.GROUP and + weight_quant.actorder in (ActivationOrdering.GROUP, + ActivationOrdering.DYNAMIC)): + raise ValueError( + "WNA16MoE is not supported with actorder=group/dynamic." + ) + logger.info_once("Using CompressedTensorsWNA16MoEMethod") + return CompressedTensorsWNA16MoEMethod(quant_config) + else: + logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") + return CompressedTensorsWNA16MarlinMoEMethod(quant_config) + elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): + return CompressedTensorsW4A4MoeMethod() + elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant): + return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) + elif quant_config._is_fp8_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8Fp8MoEMethod(quant_config) + elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8Int8MoEMethod(quant_config) + elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8Int8MoEMethod(quant_config) + else: + raise RuntimeError( + f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") + + +class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): + + def __init__(self): + self.use_marlin = not cutlass_fp4_supported() + self.group_size = 16 + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + layer.num_experts = num_experts + layer.params_dtype = params_dtype + + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // 2, + requires_grad=False, + dtype=torch.uint8), + requires_grad=False) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // 2, + dtype=torch.uint8), + requires_grad=False) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # Weight Scales + w13_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // self.group_size, + dtype=torch.float8_e4m3fn), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # Weight Global Scales + w13_weight_scale_2 = torch.nn.Parameter(torch.empty( + num_experts, 2, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_weight_scale_2, extra_weight_attrs) + + w2_weight_scale_2 = torch.nn.Parameter(torch.empty( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w2_weight_scale_2, extra_weight_attrs) + + # Input Global Scales + w13_input_scale = torch.nn.Parameter(torch.empty(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_global_scale", w13_input_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.empty(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_global_scale", w2_input_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + def swizzle_blockscale(self, scale: torch.tensor): + assert (scale.dtype == torch.float8_e4m3fn) + # Pad and blockwise interleave weight_scale + scale_ndim = scale.ndim + if scale.ndim == 2: + scale = scale.unsqueeze(0) + assert scale.ndim == 3 + B, M, K = scale.shape + round_up_multiple = lambda x, m: (x + m - 1) // m * m + M_padded = round_up_multiple(M, 128) + K_padded = round_up_multiple(K, 4) + padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) + padded_scale[:B, :M, :K] = scale + batches, rows, cols = padded_scale.shape + assert rows % 128 == 0 + assert cols % 4 == 0 + padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, + cols // 4, 4) + swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) + swizzled_scale = swizzled_scale.contiguous().cuda() + return (swizzled_scale.reshape(M, K) + if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + # From packed to weight + layer.w13_weight = torch.nn.Parameter(layer.w13_weight_packed.data, + requires_grad=False) + + layer.w2_weight = torch.nn.Parameter(layer.w2_weight_packed.data, + requires_grad=False) + + if not torch.allclose(layer.w13_weight_global_scale[:, 0], + layer.w13_weight_global_scale[:, 1]): + logger.warning_once( + "w1_weight_global_scale must match w3_weight_global_scale. " + "Accuracy may be affected.") + + # Take inverse of global scale saved to disk + layer.w13_weight_scale_2 = torch.nn.Parameter( + 1 / layer.w13_weight_global_scale[:, 0], requires_grad=False) + + layer.w2_weight_scale_2 = torch.nn.Parameter( + 1 / layer.w2_weight_global_scale.data, requires_grad=False) + + if self.use_marlin: + prepare_moe_fp4_layer_for_marlin(layer) + return + + # swizzle weight scales + layer.w13_blockscale_swizzled = torch.nn.Parameter( + self.swizzle_blockscale(layer.w13_weight_scale), + requires_grad=False) + + layer.w2_blockscale_swizzled = torch.nn.Parameter( + self.swizzle_blockscale(layer.w2_weight_scale), + requires_grad=False) + + # w13 + w13_input_global_scale = layer.w13_input_global_scale.max( + dim=1).values.to(torch.float32) + + layer.g1_alphas = torch.nn.Parameter( + ((1 / w13_input_global_scale) * layer.w13_weight_scale_2), + requires_grad=False) + + layer.w13_input_scale_quant = torch.nn.Parameter( + (w13_input_global_scale), requires_grad=False) + + # w2 + layer.g2_alphas = torch.nn.Parameter( + ((1 / layer.w2_input_global_scale) * layer.w2_weight_scale_2).to( + torch.float32), + requires_grad=False) + + layer.w2_input_scale_quant = torch.nn.Parameter( + (layer.w2_input_global_scale), requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError("EPLB not supported for " + "`CompressedTensorsW4A4MoeMethod` yet.") + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + + if self.use_marlin: + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + global_scale1=layer.w13_weight_scale_2, + global_scale2=layer.w2_weight_scale_2, + quant_type_id=scalar_types.float4_e2m1f.id, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map) + + assert activation == "silu", "Only SiLU activation is supported." + assert not apply_router_weight_on_input, ( + "Router weight on input is not " + "supported for CompressedTensorsW4A4MoeMethod.") + assert expert_map is None, ("Expert Parallelism / expert_map " + "is currently not supported for " + "CompressedTensorsW4A4MoeMethod.") + + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp4) + + # Cutlass moe takes in activations in BF16/Half precision + # and fp4 quantized weights loaded from the checkpoint + return cutlass_moe_fp4(a=x, + w1_fp4=layer.w13_weight, + w1_blockscale=layer.w13_blockscale_swizzled, + w1_alphas=layer.g1_alphas, + w2_fp4=layer.w2_weight, + w2_blockscale=layer.w2_blockscale_swizzled, + w2_alphas=layer.g2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=x.shape[0], + n=layer.w2_weight.shape[2] * 2, + k=x.shape[1], + e=layer.w13_weight.shape[0], + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, + device=x.device).to(x.dtype) + + +class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( + "weights") + self.input_quant = self.quant_config.target_scheme_map["Linear"].get( + "input_activations") + self.topk_indices_dtype = None + + per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR + and self.input_quant.strategy + == QuantizationStrategy.TENSOR) + per_channel = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL + and self.input_quant.strategy == QuantizationStrategy.TOKEN) + if not (per_tensor or per_channel): + raise ValueError( + "For FP8 Fused MoE layers, we require per tensor " + "or channelwise, dynamic per token quantization. Found " + f"{self.weight_quant}, {self.input_quant}") + + self.static_input_scales = not self.input_quant.dynamic + if self.static_input_scales and per_channel: + raise ValueError( + "For FP8 Fused MoE layer, we require either per tensor or " + "channelwise, dynamic per token quantization.") + + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + self.use_marlin = (not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN) + # Disable marlin for rocm + if current_platform.is_rocm(): + self.use_marlin = False + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_moe_enabled) + + self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + self.tritonsingleton= W8a8GetCacheJSON() + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + layer.intermediate_size_per_partition = intermediate_size_per_partition + layer.hidden_size = hidden_size + layer.num_experts = num_experts + layer.orig_dtype = params_dtype + layer.weight_block_size = None + + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + if self.weight_quant.strategy == QuantizationStrategy.TENSOR: + # Allocate 2 scales for w1 and w3 respectively. + # They are combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, 2, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-TENSOR quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL: + w13_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-CHANNEL quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.static_input_scales: + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + E=layer.w13_weight.shape[0] + N1=layer.w13_weight.shape[1] + N2=layer.w2_weight.shape[1] + K=layer.w2_weight.shape[2] + if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes: + self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K]) + + TOPK= self.tritonsingleton.topk + + json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK) + configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK) + + #warmup + if configs_dict: + self.tritonsingleton.triton_moejson_dict.update(configs_dict) + + #生成模型配置文件 + #self.tritonsingleton.gen_model_json(block_size) + return + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.static_input_scales: + assert self.input_quant.strategy == QuantizationStrategy.TENSOR + if (layer.w13_input_scale is None or layer.w2_input_scale is None): + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + if (not all_close_1d(layer.w13_input_scale) + or not all_close_1d(layer.w2_input_scale)): + logger.warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer.") + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False) + + if current_platform.is_fp8_fnuz(): + # Normalize the weights and scales + w13_weight, w13_weight_scale, w13_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, + layer.w13_input_scale) + w2_weight, w2_weight_scale, w2_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, + layer.w2_input_scale) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, + requires_grad=False) + if w13_input_scale is not None: + layer.w13_input_scale = torch.nn.Parameter(w13_input_scale, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, + requires_grad=False) + if w2_input_scale is not None: + layer.w2_input_scale = torch.nn.Parameter(w2_input_scale, + requires_grad=False) + + # For Per-TENSOR case, Fp8 moe kernel needs single weight scale + # for w13 per expert. Use max then dequant and requant each expert. + if self.weight_quant.strategy == QuantizationStrategy.TENSOR: + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.local_num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + + # Property to determine if AITER is used + if self.rocm_aiter_moe_enabled: + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 + rocm_aiter_fused_experts, shuffle_weights) + + # reshaping weights is required for aiter moe kernel. + shuffled_w13, shuffled_w2 = shuffle_weights( + layer.w13_weight.data, layer.w2_weight.data) + + layer.w13_weight = torch.nn.Parameter(shuffled_w13, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, + requires_grad=False) + + self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts + elif self.use_marlin: + prepare_moe_fp8_layer_for_marlin(layer, False) + # Activations not quantized for marlin. + del layer.w13_input_scale + del layer.w2_input_scale + self.fused_experts_func = None + else: + from vllm.model_executor.layers.fused_moe import fused_experts + self.fused_experts_func = fused_experts + + def select_gemm_impl( + self, + prepare_finalize: FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + ) -> FusedMoEPermuteExpertsUnpermute: + from vllm.model_executor.layers.fused_moe import TritonExperts + from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts) + + assert not self.rocm_aiter_moe_enabled and not self.use_marlin + + logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) + + if (prepare_finalize.activation_format == + FusedMoEActivationFormat.BatchedExperts): + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank( + ) + assert max_num_tokens_per_rank is not None + + return BatchedTritonExperts( + max_num_tokens=max_num_tokens_per_rank, + num_dispatchers=prepare_finalize.num_dispatchers(), + use_fp8_w8a8=True, + block_shape=self.quant_config.weight_block_size, + per_act_token_quant=( + self.input_quant.strategy == QuantizationStrategy.TOKEN), + ) + else: + return TritonExperts( + use_fp8_w8a8=True, + block_shape=self.quant_config.weight_block_size, + per_act_token_quant=( + self.input_quant.strategy == QuantizationStrategy.TOKEN), + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for " + "`CompressedTensorsW8A8Fp8MoEMethod` yet.") + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + ) + + if self.rocm_aiter_moe_enabled: + return self.rocm_aiter_fused_experts_func( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=True, + per_channel_quant=self.weight_quant.strategy == + QuantizationStrategy.CHANNEL, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + expert_map=expert_map) + if self.use_marlin: + assert activation == "silu", ( + f"{activation} not supported for Marlin MoE.") + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + quant_type_id=scalar_types.float8_e4m3fn.id, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map) + + assert self.fused_experts_func is not None + + return self.fused_experts_func( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=True, + per_channel_quant=self.weight_quant.strategy == + QuantizationStrategy.CHANNEL, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) + + +class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( + "weights") + self.input_quant = self.quant_config.target_scheme_map["Linear"].get( + "input_activations") + + per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR + and self.input_quant.strategy + == QuantizationStrategy.TENSOR) + per_channel = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL + and self.input_quant.strategy == QuantizationStrategy.TOKEN) + if not (per_tensor or per_channel): + raise ValueError( + "For FP8 Fused MoE layers, we require per tensor " + "or channelwise, dynamic per token quantization. Found " + f"{self.weight_quant}, {self.input_quant}") + + self.static_input_scales = not self.input_quant.dynamic + if self.static_input_scales and per_channel: + raise ValueError( + "For FP8 Fused MoE layer, we require either per tensor or " + "channelwise, dynamic per token quantization.") + + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp8) + self.topk_indices_dtype = None + self.fused_experts = cutlass_moe_fp8 # type: ignore + self.disable_expert_map = False + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + if self.weight_quant.strategy == QuantizationStrategy.TENSOR: + # Allocate 2 scales for w1 and w3 respectively. + # They are combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, 2, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-TENSOR quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL: + w13_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-CHANNEL quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.static_input_scales: + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.static_input_scales: + assert self.input_quant.strategy == QuantizationStrategy.TENSOR + if (layer.w13_input_scale is None or layer.w2_input_scale is None): + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + if (not all_close_1d(layer.w13_input_scale) + or not all_close_1d(layer.w2_input_scale)): + logger.warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer.") + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False) + + # For Per-TENSOR case, Fp8 moe kernel needs single weight scale + # for w13 per expert. Use max then dequant and requant each expert. + if self.weight_quant.strategy == QuantizationStrategy.TENSOR: + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.local_num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + + def select_gemm_impl( + self, + prepare_finalize: FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + ) -> FusedMoEPermuteExpertsUnpermute: + from vllm.model_executor.layers.fused_moe import CutlassExpertsFp8 + + use_batched_format = (prepare_finalize.activation_format == + FusedMoEActivationFormat.BatchedExperts) + + num_dispatchers = prepare_finalize.num_dispatchers() + + num_experts = (moe.num_local_experts + if use_batched_format else moe.num_experts) + + logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) + + experts = CutlassExpertsFp8( + num_experts, + moe.in_dtype, + self.input_quant.strategy == QuantizationStrategy.TOKEN, + self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + num_dispatchers=num_dispatchers, + use_batched_format=use_batched_format, + ) + + self.disable_expert_map = (num_dispatchers > 1 + or not experts.supports_expert_map()) + + return experts + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for " + "`CompressedTensorsW8A8Fp8MoECutlassMethod` yet.") + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + ) + + return self.fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=None if self.disable_expert_map else expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) + + +class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( + "weights") + self.input_quant = self.quant_config.target_scheme_map["Linear"].get( + "input_activations") + + per_channel = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL + and self.input_quant.strategy == QuantizationStrategy.TOKEN) + if not per_channel: + raise ValueError( + "For INT8 Fused MoE layers, we require channelwise, " + "dynamic per token quantization. Found " + f"{self.weight_quant}, {self.input_quant}") + + self.static_input_scales = not self.input_quant.dynamic + if self.static_input_scales: + raise ValueError( + "For INT8 Fused MoE layers, we require channelwise, " + "dynamic per token quantization. Found static input scales.") + self.tritonsingleton= W8a8GetCacheJSON() + + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + params_dtype = torch.int8 + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL + w13_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + hidden_size, + 1, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-CHANNEL quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + assert not self.static_input_scales + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + E=layer.w13_weight.shape[0] + N1=layer.w13_weight.shape[1] + N2=layer.w2_weight.shape[1] + K=layer.w2_weight.shape[2] + if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes: + self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K]) + + TOPK= self.tritonsingleton.topk + + json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK) + configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK) + + #warmup + if configs_dict: + self.tritonsingleton.triton_moejson_dict.update(configs_dict) + + pass + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + use_nn_moe: Optional[bool] = False, + routed_scaling_factor: Optional[float] = None, + use_fused_gate: Optional[bool] = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + + ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for " + "`CompressedTensorsW8A8Int8MoEMethod` yet.") + + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + use_fused_gate=use_fused_gate, + e_score_correction_bias=e_score_correction_bias) + + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_int8_w8a8=True, + per_channel_quant=True, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + use_nn_moe=False) + + +class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + # TODO: @dsikka: refactor this to use schemes as other kernels + # are supported + check if the layer is being ignored. + config = self.quant_config.target_scheme_map["Linear"].get("weights") + self.num_bits = config.num_bits + self.packed_factor = 32 // config.num_bits + self.strategy = config.strategy + self.group_size = config.group_size + self.actorder = config.actorder + assert config.symmetric, ( + "Only symmetric quantization is supported for MoE") + + if not (self.quant_config.quant_format + == CompressionFormat.pack_quantized.value + and self.num_bits in WNA16_SUPPORTED_BITS): + raise ValueError("For Fused MoE layers, only ", + f"{CompressionFormat.pack_quantized.value} ", + "is supported for the following bits: ", + f"{WNA16_SUPPORTED_BITS}") + self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits] + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + intermediate_size_full = extra_weight_attrs.pop( + "intermediate_size_full") + + # Will transpose the loaded weight along the + # intermediate and hidden dim sizes. Will + # shard for TP along the transposed dims + extra_weight_attrs.update({ + "is_transposed": True, + "quant_method": self.strategy + }) + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size // self.packed_factor, + 2 * intermediate_size_per_partition, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + intermediate_size_per_partition // self.packed_factor, + hidden_size, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # In the case where we have actorder/g_idx, + # we do not partition the w2 scales + load_full_w2 = self.actorder and self.group_size != -1 + w2_scales_size = (intermediate_size_full + if load_full_w2 else intermediate_size_per_partition) + + self.is_k_full = (not self.actorder) or ( + intermediate_size_per_partition == intermediate_size_full) + + if self.strategy == "channel": + num_groups_w2 = num_groups_w13 = 1 + self.group_size = -1 + else: + num_groups_w2 = w2_scales_size // self.group_size + num_groups_w13 = hidden_size // self.group_size + + w13_scale = torch.nn.Parameter(torch.ones( + num_experts, + num_groups_w13, + 2 * intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_scale) + set_weight_attrs(w13_scale, extra_weight_attrs) + + w2_scale = torch.nn.Parameter(torch.ones(num_experts, + num_groups_w2, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_scale) + set_weight_attrs(w2_scale, extra_weight_attrs) + set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2}) + + w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), + requires_grad=False) + layer.register_parameter("w2_weight_shape", w2_weight_shape) + set_weight_attrs(w2_weight_shape, extra_weight_attrs) + w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), + requires_grad=False) + + layer.register_parameter("w13_weight_shape", w13_weight_shape) + set_weight_attrs(w13_weight_shape, extra_weight_attrs) + + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + + layer.a13_scale = None + layer.a2_scale = None + layer.marlin_state = GPTQMarlinState.REPACK + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + num_experts = layer.w13_weight_g_idx.shape[0] + device = layer.w13_weight_g_idx.device + + # when running models with grouped act order, + # resort to g_idx values provided in checkpoint + if self.actorder == "group": + w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx) + w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx) + w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx) + w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx) + + for e in range(num_experts): + w13_g_idx_sort_indices[e] = torch.argsort( + layer.w13_weight_g_idx[e]).to(torch.int32) + w2_g_idx_sort_indices[e] = torch.argsort( + layer.w2_weight_g_idx[e]).to(torch.int32) + w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][ + w13_g_idx_sort_indices[e]] + w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][ + w2_g_idx_sort_indices[e]] + + replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx) + replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx) + replace_parameter(layer, "w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + replace_parameter(layer, "w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + + else: + layer.w13_weight_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_weight_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_weight_packed, + layer.w13_g_idx_sort_indices, + layer.w13_weight_packed.shape[1] * self.packed_factor, + layer.w13_weight_packed.shape[2], + self.num_bits, + ) + replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_weight_packed, + layer.w2_g_idx_sort_indices, + layer.w2_weight_packed.shape[1] * self.packed_factor, + layer.w2_weight_packed.shape[2], + self.num_bits, + ) + replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_weight_scale, + size_k=layer.w13_weight_packed.shape[2], + size_n=layer.w13_weight_scale.shape[2], + group_size=self.group_size, + ) + replace_parameter(layer, "w13_weight_scale", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_weight_scale, + size_k=layer.w2_weight_scale.shape[1] * + (self.group_size if self.group_size != -1 else self.packed_factor), + size_n=layer.w2_weight_scale.shape[2], + group_size=self.group_size, + ) + replace_parameter(layer, "w2_weight_scale", marlin_w2_scales) + + layer.workspace = marlin_make_workspace_new(device, 4) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for " + "`CompressedTensorsWNA16MarlinMoEMethod` yet.") + + assert activation == "silu", ( + f"{activation} not supported for Marlin MoE.") + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight_packed, + layer.w2_weight_packed, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + quant_type_id=self.quant_type.id, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + g_idx1=layer.w13_weight_g_idx, + g_idx2=layer.w2_weight_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, + workspace=layer.workspace, + is_k_full=self.is_k_full) + + +class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + # TODO: @dsikka: refactor this to use schemes as other kernels + # are supported + check if the layer is being ignored. + config = self.quant_config.target_scheme_map["Linear"].get("weights") + self.num_bits = config.num_bits + self.packed_factor = 32 // config.num_bits + self.strategy = config.strategy + # channelwise is not supported by this kernel + assert config.strategy == "group" + self.group_size = config.group_size + # grouped actorder isn't supported by this kernel + assert config.actorder != "group" + assert config.symmetric, ( + "Only symmetric quantization is supported for MoE") + + if not (self.quant_config.quant_format + == CompressionFormat.pack_quantized.value + and self.num_bits in WNA16_SUPPORTED_BITS): + raise ValueError("For Fused MoE layers, only ", + f"{CompressionFormat.pack_quantized.value} ", + "is supported for the following bits: ", + f"{WNA16_SUPPORTED_BITS}") + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + # Will transpose the loaded weight along the + # intermediate and hidden dim sizes. Will + # shard for TP along the transposed dims + extra_weight_attrs.update({ + "is_transposed": True, + "quant_method": self.strategy + }) + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size // self.packed_factor, + 2 * intermediate_size_per_partition, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + intermediate_size_per_partition // self.packed_factor, + hidden_size, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w2_scales_size = intermediate_size_per_partition + + if self.strategy == "channel": + num_groups_w2 = num_groups_w13 = 1 + self.group_size = -1 + else: + num_groups_w2 = w2_scales_size // self.group_size + num_groups_w13 = hidden_size // self.group_size + + w13_scale = torch.nn.Parameter(torch.ones( + num_experts, + num_groups_w13, + 2 * intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_scale) + set_weight_attrs(w13_scale, extra_weight_attrs) + + w2_scale = torch.nn.Parameter(torch.ones(num_experts, + num_groups_w2, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_scale) + set_weight_attrs(w2_scale, extra_weight_attrs) + set_weight_attrs(w2_scale, {"load_full_w2": False}) + + w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), + requires_grad=False) + layer.register_parameter("w2_weight_shape", w2_weight_shape) + set_weight_attrs(w2_weight_shape, extra_weight_attrs) + w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), + requires_grad=False) + + layer.register_parameter("w13_weight_shape", w13_weight_shape) + set_weight_attrs(w13_weight_shape, extra_weight_attrs) + + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + + layer.a13_scale = None + layer.a2_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Reconfigure packed weights and scales to match moe_wna16 format + layer.w13_weight_packed = torch.nn.Parameter( + layer.w13_weight_packed.transpose(1, 2).contiguous().view( + torch.uint8), + requires_grad=False) + layer.w2_weight_packed = torch.nn.Parameter( + layer.w2_weight_packed.transpose(1, + 2).contiguous().view(torch.uint8), + requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + layer.w13_weight_scale.transpose(1, 2).contiguous(), + requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + layer.w2_weight_scale.transpose(1, 2).contiguous(), + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError("EPLB not supported for " + "`CompressedTensorsWNA16MoEMethod` yet.") + + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return fused_experts( + x, + layer.w13_weight_packed, + layer.w2_weight_packed, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + use_int4_w4a16=self.num_bits == 4, + use_int8_w8a16=self.num_bits == 8, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + w1_zp=None, + w2_zp=None, + block_shape=[0, self.group_size]) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py new file mode 100644 index 0000000..6e4e75d --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .compressed_tensors_scheme import CompressedTensorsScheme +from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 +from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS, + CompressedTensorsW4A16Sparse24) +from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4 +from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 +from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8 +from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8 +from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS, + CompressedTensorsWNA16) + +from .compressed_tensors_24 import CompressedTensors24 # isort: skip + +__all__ = [ + "CompressedTensorsScheme", "CompressedTensorsWNA16", + "CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24", + "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8", + "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS", + "CompressedTensors24", "CompressedTensorsW4A16Fp4", + "CompressedTensorsW4A4Fp4" +] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py new file mode 100644 index 0000000..30ed55a --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -0,0 +1,358 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Callable, Optional + +import torch +from compressed_tensors import CompressionFormat, ModelCompressor +from compressed_tensors.quantization import (QuantizationArgs, + QuantizationStrategy, + QuantizationType) +from compressed_tensors.utils import combine_shards + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise, sparse_cutlass_supported) +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) + +__all__ = ["CompressedTensors24"] + + +class CompressedTensors24(CompressedTensorsScheme): + + def __init__( + self, + quantized: bool = False, + weight_quant: Optional[QuantizationArgs] = None, + input_quant: Optional[QuantizationArgs] = None, + model_compression_config: Optional[dict[str, Any]] = None, + ): + self.quantized = quantized + self.weight_quant = weight_quant + self.input_quant = input_quant + self.model_compressor = ( + ModelCompressor.from_compression_config(model_compression_config) + if model_compression_config is not None else None) + self.do_sparse_decompress = ( + self.model_compressor is not None + and self.model_compressor.sparsity_config.format + == CompressionFormat.sparse_24_bitmask.value) + + @classmethod + def get_min_capability(cls) -> int: + # Only cutlass 3.x kernels are implemented so far + return 90 + + def create_weights( + self, + layer: torch.nn.Module, + input_size: int, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): + if not sparse_cutlass_supported(): + raise ValueError( + "Sparse CUTLASS not supported. vLLM must be built with " + "CUDA 12.2 or later to use this feature") + + layer.logical_widths = output_partition_sizes + layer.input_size = input_size + layer.input_size_per_partition = input_size_per_partition + self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype) + + # parameter to store uncompressed weight + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=self.weights_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + if self.do_sparse_decompress: + assert all(partition_size % 8 == 0 + for partition_size in output_partition_sizes + ), "All partitions must be divisible by 8 for " + "2:4 sparse compressed models" + + shape = BasevLLMParameter( + data=torch.empty(2, 1, dtype=torch.int64), + weight_loader=weight_loader, + ) + compressed_weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=self.weights_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + bitmask = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 8, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + layer.register_parameter("shape", shape) + layer.register_parameter("compressed", compressed_weight) + layer.register_parameter("bitmask", bitmask) + + # Check if quantized, not just 2:4 Sparse + if self.quantized: + if (self.weight_quant and self.weight_quant.strategy + == QuantizationStrategy.CHANNEL.value): + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) + else: + assert (self.weight_quant and self.weight_quant.strategy + == QuantizationStrategy.TENSOR.value) + weight_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), + dtype=torch.float32), + weight_loader=weight_loader, + ) + + layer.register_parameter("weight_scale", weight_scale) + + # input quant will be non-none + if self.input_quant and not self.input_quant.dynamic: + # register input quant scale + assert (self.input_quant.strategy == + QuantizationStrategy.TENSOR.value) + input_scale = BasevLLMParameter( + data=torch.empty(1, dtype=torch.float32), + weight_loader=weight_loader, + ) + + layer.register_parameter("input_scale", input_scale) + + else: + # for sparse-only, pass in 1 for weight/input scales + weight_scale = torch.nn.Parameter(data=torch.ones( + 1, dtype=torch.float32), + requires_grad=False) + input_scale = torch.nn.Parameter(data=torch.ones( + 1, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("input_scale", input_scale) + layer.register_parameter("weight_scale", weight_scale) + + layer.register_parameter("weight", weight) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """ + Compress weights after loading. Store compressed weight and meta + tensor + + :post-condition: layer.w_compressed and layer.meta are + set to the compressed weight and meta tensor in the + format expected by the Cutlass kernels + :param layer: The layer with the weights to be processed + + """ + if self.do_sparse_decompress: + layer.weight.data = self._decompress_bitmask_compressed_weight( + compressed=layer.compressed, + bitmask=layer.bitmask, + layer=layer, + ) + + # compressed and bitmask tensors + # are no longer needed after decompression + del layer.compressed + del layer.bitmask + + # torch.compile workaround + if hasattr(layer, "input_scale"): + layer.input_scale = torch.nn.Parameter(layer.input_scale.data, + requires_grad=False) + + if self.weight_quant: + if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value: + layer.weight_scale = torch.nn.Parameter( + convert_to_channelwise( + weight_scale=layer.weight_scale, + logical_widths=layer.logical_widths, + ), + requires_grad=False, + ) + else: + # torch.compile workaround + layer.weight_scale = torch.nn.Parameter( + layer.weight_scale.data, requires_grad=False) + + # Set all negative zero values to 0 prior to compression + if (layer.weight.dtype.is_floating_point + and layer.weight.dtype.itemsize >= 2): + layer.weight.data[layer.weight.data == -0.0] = 0.0 + + w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data) + layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False) + layer.meta = torch.nn.Parameter(meta, requires_grad=False) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Returns the output tensor for the layer with 2:4 + sparse compressed weights, given the input tensor + and bias + + :param layer: The layer with 2:4 sparse compressed + weights to be used for the computation + :param x: The input tensor to the layer + :param bias: The bias to be added to the output tensor + :return: The output tensor of the layer + """ + if self.quantized: + scale = None + if hasattr(layer, "input_scale"): + scale = layer.input_scale + + if self.weights_dtype == torch.int8: + ops_output = ops.scaled_int8_quant(x, scale=scale) + q_input = ops_output[0] + input_scale = ops_output[1] + else: + assert self.weights_dtype == torch.float8_e4m3fn + if scale is not None: + q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale) + else: + q_input, input_scale = ops.scaled_fp8_quant( + x, use_per_token_if_dynamic=True) + + else: + # Not quantized, nothing to do with the input_scales, use as is + input_scale = layer.input_scale + q_input = x + + out = ops.cutlass_scaled_sparse_mm( + a=q_input, + bt_nzs=layer.weight, + bt_meta=layer.meta, + scale_a=input_scale, + scale_b=layer.weight_scale, + out_dtype=x.dtype, + bias=bias, + ) + + assert out.is_contiguous() + return out + + def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype: + if not self.quantized: + return params_dtype + + assert self.weight_quant is not None + assert self.input_quant is not None + + is_8_bits = self.weight_quant.num_bits == self.input_quant.num_bits == 8 + + if not is_8_bits: + raise ValueError("Cutlass only supports 8-bit quantization") + + if (self.weight_quant.type == QuantizationType.FLOAT + and self.input_quant.type == QuantizationType.FLOAT): + return torch.float8_e4m3fn + + if (self.weight_quant.type == QuantizationType.INT + and self.input_quant.type == QuantizationType.INT): + return torch.int8 + + raise ValueError("Quantization type not supported by Cutlass") + + def _decompress_bitmask_compressed_weight( + self, + compressed: torch.Tensor, + bitmask: torch.Tensor, + layer: torch.nn.Module, + ) -> torch.Tensor: + """ + Decompress a compressed 2:4 sparse weight tensor using the bitmask and + return the result. + + This function also supports sharded decompression. + + :param compressed: The 2:4 sparse weight tensor compressed using the + sparse-24-bitmask compressor. This is different from + `cutlass_sparse_compress` which uses a different scheme (2 bits for + every nonzero element that represent the coordinate within the block + of 4). The bitmask compression here uses a bitmask to indicate the + positions of non-zero elements. + :param bitmask: The 2:4 bitmask associated with the compressed weights, + representing the positions of non-zero elements in the compressed + tensor. + :param layer: The layer whose weights need to be processed after + loading. + :return: The decompressed 2:4 sparse weight tensor. + """ + + sparsity_compressor = self.model_compressor.sparsity_compressor + + def _process_split( + bitmask_compressed_weight: torch.Tensor, + shape, + bitmask: torch.Tensor, + ) -> torch.Tensor: + weight_data = dict( + compressed=bitmask_compressed_weight, + shape=shape, + bitmask=bitmask, + ) + return sparsity_compressor.decompress_weight(weight_data) + + split_weights: list[torch.Tensor] = [] + split_bitmask: list[torch.Tensor] = [] + split_shape: list[tuple[int, int]] = [] + + if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)): + split_weights = torch.split(compressed, layer.logical_widths) + split_bitmask = torch.split(bitmask, layer.logical_widths) + split_shape = [(out, layer.input_size_per_partition) + for out in layer.logical_widths] + + if split_weights: + decompressed_shards = [ + _process_split(compressed_weight, shape, bitmask) + for compressed_weight, shape, bitmask in zip( + split_weights, split_shape, split_bitmask) + ] + decompressed = combine_shards(decompressed_shards) + else: + decompressed = sparsity_compressor.decompress_weight( + dict( + compressed=compressed, + shape=( + layer.logical_widths[0], + layer.input_size_per_partition, + ), + bitmask=bitmask, + )) + return decompressed diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py new file mode 100644 index 0000000..a5d48f2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from typing import Optional + +import torch + +__all__ = ["CompressedTensorsScheme"] + + +class CompressedTensorsScheme(ABC): + """ + Abstract class used to describe the weight creation and forward pass + of different quantization schemes supported by CompressedTensors. + """ + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + """ + Get minimum device capability. + """ + raise NotImplementedError + + @abstractmethod + def create_weights(self, *args, **kwargs): + """ + Weight creation for the particular scheme. Inputs to this function + + """ + raise NotImplementedError + + @abstractmethod + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]): + """ + Run the forward pass for the particular scheme. This is where + scheme-specific dequant/quant steps/kernels should be applied. + + :param layer: torch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. + :param x: input to the layer + :param bias: bias parameter + + """ + raise NotImplementedError + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module): + """ + Called after weight loading is complete for any cleanup that + needs to occur. + """ + raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py new file mode 100644 index 0000000..3f3e766 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Callable, Optional + +import torch +from torch.nn import Parameter + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( + GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N) +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.scalar_type import scalar_types + +__all__ = ["CompressedTensorsW4A16Sparse24"] +W4A16SPARSE24_SUPPORTED_TYPES_MAP = { + 4: scalar_types.uint4b8, +} +W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys()) + + +class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): + + def __init__(self, + strategy: str, + num_bits: int, + group_size: Optional[int] = None): + self.strategy = strategy + self.group_size = group_size + self.tile_size = 16 + + if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP: + raise ValueError( + f"Unsupported num_bits = {num_bits}. " + f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}") + + self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits] + + if self.strategy == "group" and self.group_size is None: + raise ValueError( + "group_size must be given when using strategy group") + + @classmethod + def get_min_capability(cls) -> int: + # ampere + up + return 80 + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # required by torch.compile to be torch.nn.Parameter + layer.weight_packed = Parameter(layer.weight_packed.data, + requires_grad=False) + layer.scale_packed = Parameter(layer.scale_packed.data, + requires_grad=False) + layer.meta = Parameter(layer.meta.data, requires_grad=False) + + def create_weights(self, layer: torch.nn.Module, input_size: int, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + assert params_dtype == torch.float16, ( + "float16 is required for marlin24 compressed models. Set dtype=torch.float16" # noqa: E501 + ) + + pack_factor = 32 // self.quant_type.size_bits + output_size_per_partition = sum(output_partition_sizes) + + qweight = PackedvLLMParameter(data=torch.empty( + input_size_per_partition // self.tile_size // 2, + output_size_per_partition * self.tile_size // pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=pack_factor, + marlin_tile_size=self.tile_size, + weight_loader=weight_loader) + + input_groups = (1 if self.group_size is None else + input_size_per_partition // self.group_size) + + weight_scale_args = { + "data": + torch.empty( + input_groups, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + + if self.group_size is not None: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + else: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + + weight_shape = BasevLLMParameter(data=torch.empty(2, + dtype=torch.int64), + weight_loader=weight_loader) + + meta = PackedvLLMParameter(data=torch.empty( + input_size_per_partition // 8 // 2 // 2, + output_size_per_partition * 2, + dtype=torch.int16, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=1, + marlin_tile_size=2, + weight_loader=weight_loader) + + layer.register_parameter("weight_packed", qweight) + layer.register_parameter("weight_shape", weight_shape) + layer.register_parameter("scale_packed", scales) + layer.register_parameter("meta", meta) + + max_workspace_size = ( + output_size_per_partition // + GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL + + workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int), + requires_grad=False) + layer.workspace = workspace + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + + qweight = layer.weight_packed + meta = layer.meta + scales = layer.scale_packed + workspace = layer.workspace + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = scales.shape[1] + + output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, + workspace, self.quant_type, size_m, + size_n, size_k) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py new file mode 100644 index 0000000..96dccf0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Callable, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, prepare_fp4_layer_for_marlin) +from vllm.model_executor.parameter import (GroupQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) + +__all__ = ["CompressedTensorsW4A16Fp4"] + + +class CompressedTensorsW4A16Fp4(CompressedTensorsScheme): + + def __init__(self, has_input_global_scale: bool = False): + self.has_input_global_scale = has_input_global_scale + self.group_size = 16 + + @classmethod + def get_min_capability(cls) -> int: + # dont restrict as emulations + return 80 + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + # Weight + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=torch.uint8), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight_packed", weight) + + # Global Weight Scale + weight_global_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("weight_global_scale", weight_global_scale) + + # Per Group Weight Scale + weight_scale = GroupQuantScaleParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("weight_scale", weight_scale) + + if self.has_input_global_scale: + input_global_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), + dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("input_global_scale", input_global_scale) + + def process_weights_after_loading(self, layer) -> None: + # Process parameters for marlin repacking + + # Rename weight_packed to weight that marlin expects + layer.weight = Parameter(layer.weight_packed.data, requires_grad=False) + del layer.weight_packed + # Rename weight_global_scale to weight_scale_2 that marlin expects + # Note: ct stores the inverse of what is expected by the marlin kernel + layer.weight_scale_2 = Parameter( + 1 / layer.weight_global_scale.max().to(torch.float32), + requires_grad=False) + del layer.weight_global_scale + + if self.has_input_global_scale: + layer.input_global_scale = torch.nn.Parameter( + layer.input_global_scale.data, requires_grad=False) + + prepare_fp4_layer_for_marlin(layer) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return apply_fp4_marlin_linear(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_scale_2=layer.weight_scale_2, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py new file mode 100644 index 0000000..8ba7216 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Callable, Optional + +import torch +from torch.nn.parameter import Parameter + +import vllm.envs as envs +from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 + run_nvfp4_emulations) +from vllm.model_executor.parameter import (GroupQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) + +logger = init_logger(__name__) + +__all__ = ["CompressedTensorsW4A4Fp4"] + + +class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): + + def __init__(self): + self.group_size = 16 + + @classmethod + def get_min_capability(cls) -> int: + if envs.VLLM_USE_NVFP4_CT_EMULATIONS: + return 80 + return 100 + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + # Weight + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=torch.uint8), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight_packed", weight) + + # Global Weight Scale + weight_global_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("weight_global_scale", weight_global_scale) + + # Per Group Weight Scale + weight_scale = GroupQuantScaleParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("weight_scale", weight_scale) + + input_global_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("input_global_scale", input_global_scale) + + def swizzle_blockscale(self, scale: torch.tensor): + assert (scale.dtype == torch.float8_e4m3fn) + # Pad and blockwise interleave weight_scale + scale_ndim = scale.ndim + if scale.ndim == 2: + scale = scale.unsqueeze(0) + assert scale.ndim == 3 + B, M, K = scale.shape + round_up_multiple = lambda x, m: (x + m - 1) // m * m + M_padded = round_up_multiple(M, 128) + K_padded = round_up_multiple(K, 4) + padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) + padded_scale[:B, :M, :K] = scale + batches, rows, cols = padded_scale.shape + assert rows % 128 == 0 + assert cols % 4 == 0 + padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, + cols // 4, 4) + swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) + swizzled_scale = swizzled_scale.contiguous().cuda() + return (swizzled_scale.reshape(M, K) + if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) + + def process_weights_after_loading(self, layer) -> None: + + global_input_scale = layer.input_global_scale.max().to(torch.float32) + layer.input_global_scale = Parameter(global_input_scale, + requires_grad=False) + + layer.weight_global_scale = Parameter( + layer.weight_global_scale.max().to(torch.float32), + requires_grad=False) + + swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale) + layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, + requires_grad=False) + + # required by cutlass kernel; need Parameter, not ModelWeightParameter + layer.weight = Parameter(layer.weight_packed.data, requires_grad=False) + + layer.alpha = Parameter(layer.input_global_scale * + layer.weight_global_scale, + requires_grad=False) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + if envs.VLLM_USE_NVFP4_CT_EMULATIONS: + out = run_nvfp4_emulations( + x=x, + input_global_scale=layer.input_global_scale, + weight=layer.weight, + weight_scale_swizzled=layer.weight_scale_swizzled, + weight_global_scale=layer.weight_global_scale) + if bias is not None: + out = out + bias + return out + + output_dtype = x.dtype + output_shape = [x.shape[0], layer.weight.shape[0]] + + # quantize BF16 or FP16 to (FP4 and interleaved block scale) + x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale) + + out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale, + layer.weight_scale_swizzled, + 1 / layer.alpha, output_dtype) + if bias is not None: + out = out + bias + return out.view(*output_shape) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py new file mode 100644 index 0000000..01a87a0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Callable, Optional + +import torch +from compressed_tensors.quantization import QuantizationStrategy + +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) + +__all__ = ["CompressedTensorsW8A16Fp8"] + +SUPPORTED_STRATEGIES = [ + QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR +] + + +class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): + + def __init__(self, strategy: str, is_static_input_scheme: bool): + self.strategy = strategy + self.is_static_input_scheme = is_static_input_scheme + + @classmethod + def get_min_capability(cls) -> int: + # ampere and up + return 80 + + # W8A8-Fp8 kernels support only per-tensor and per-channel cases. + # So if we have a fused module (QKV, MLP) with per tensor scales, + # we expand each scale to its shard's channels. + def process_weights_after_loading(self, layer) -> None: + if self.strategy == QuantizationStrategy.TENSOR: + ws_channelwise = convert_to_channelwise(layer.weight_scale, + layer.logical_widths) + layer.weight_scale = torch.nn.Parameter(ws_channelwise, + requires_grad=False) + else: + # required by torch.compile to be torch.nn.Parameter + layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, + requires_grad=False) + + # Weights must be transposed for marlin + layer.weight = torch.nn.Parameter(layer.weight.t(), + requires_grad=False) + + if self.is_static_input_scheme: + # required by torch.compile to be torch.nn.Parameter + layer.input_scale = torch.nn.Parameter(layer.input_scale.data, + requires_grad=False) + prepare_fp8_layer_for_marlin(layer) + + def create_weights(self, layer: torch.nn.Module, input_size: int, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + layer.weight_block_size = None + + # WEIGHT + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + if self.strategy == QuantizationStrategy.CHANNEL: + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + elif self.strategy == QuantizationStrategy.TENSOR: + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + else: + raise ValueError( + f"Unsupported weight strategy={self.strategy}, " + f"supported strategies are {SUPPORTED_STRATEGIES}") + + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE (to deal with converted checkpoints) + if self.is_static_input_scheme: + input_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("input_scale", input_scale) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + return apply_fp8_marlin_linear(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py new file mode 100644 index 0000000..1e61e05 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Callable, Optional + +import torch +from compressed_tensors.quantization import QuantizationStrategy +from torch.nn import Parameter + +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, + requantize_with_max_scale) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) +from vllm.platforms import current_platform + +__all__ = ["CompressedTensorsW8A8Fp8"] + + +class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): + + def __init__(self, strategy: str, is_static_input_scheme: bool): + self.strategy = strategy + self.out_dtype = torch.get_default_dtype() + self.is_static_input_scheme = is_static_input_scheme + self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True) + + @classmethod + def get_min_capability(cls) -> int: + # lovelace and up + return 89 + + def process_weights_after_loading(self, layer) -> None: + # If per tensor, when we have a fused module (e.g. QKV) with per + # tensor scales (thus N scales being passed to the kernel), + # requantize so we can always run per tensor + if self.strategy == QuantizationStrategy.TENSOR: + max_w_scale, weight = requantize_with_max_scale( + weight=layer.weight, + weight_scale=layer.weight_scale, + logical_widths=layer.logical_widths, + ) + + if current_platform.is_fp8_fnuz(): + input_scale = getattr(layer, 'input_scale', None) + + weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=max_w_scale, + input_scale=input_scale) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, + requires_grad=False) + + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + + # If channelwise, scales are already lined up, so just transpose. + elif self.strategy == QuantizationStrategy.CHANNEL: + weight = layer.weight + + if current_platform.is_fp8_fnuz(): + input_scale = getattr(layer, 'input_scale', None) + + weight, weight_scale, input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=layer.weight_scale, + input_scale=input_scale) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, + requires_grad=False) + else: + weight_scale = layer.weight_scale.data + + layer.weight = Parameter(weight.t(), requires_grad=False) + # required by torch.compile to be torch.nn.Parameter + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + + else: + raise ValueError(f"Unknown quantization strategy {self.strategy}") + + # INPUT SCALE + if self.is_static_input_scheme and hasattr(layer, 'input_scale'): + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) + else: + layer.input_scale = None + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + maybe_create_device_identity() + + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + + # WEIGHT + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + # TODO: update create_xxx_parameter functions to return + # the newly added parameters + if self.strategy == QuantizationStrategy.CHANNEL: + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + else: + assert self.strategy == QuantizationStrategy.TENSOR + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + + # min requirement for fp8 kernels + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE + if self.is_static_input_scheme: + input_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + input_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", input_scale) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + return self.fp8_linear.apply(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, + input_scale=layer.input_scale, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py new file mode 100644 index 0000000..f493cc1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Callable, Optional + +import torch +from compressed_tensors.quantization import QuantizationStrategy +import os + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import apply_int8_linear + +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) + +logger = init_logger(__name__) + + +class CompressedTensorsW8A8Int8(CompressedTensorsScheme): + _kernel_backends_being_used: set[str] = set() + + def __init__(self, strategy: str, is_static_input_scheme: bool, + input_symmetric: bool): + self.strategy = strategy + self.is_static_input_scheme = is_static_input_scheme + self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) + self.input_symmetric = input_symmetric + + @classmethod + def get_min_capability(cls) -> int: + # turing and up + return 75 + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + layer.logical_widths = output_partition_sizes + + scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( + is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL), + is_static_input_scheme=self.is_static_input_scheme, + input_symmetric=self.input_symmetric) + + kernel_type = choose_scaled_mm_linear_kernel( + scaled_mm_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for CompressedTensorsW8A8Int8", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # WEIGHT + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.int8), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + if self.strategy == QuantizationStrategy.CHANNEL: + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + else: + assert self.strategy == QuantizationStrategy.TENSOR + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE + if self.is_static_input_scheme: + input_scale = BasevLLMParameter(data=torch.empty( + 1, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("input_scale", input_scale) + + if not self.input_symmetric: + # Note: compressed-tensors stores the zp using the same dtype + # as the weights + # AZP loaded as int8 but used as int32 + input_zero_point = BasevLLMParameter( + data=torch.empty(1, dtype=torch.int8), + weight_loader=weight_loader) + layer.register_parameter("input_zero_point", input_zero_point) + + self.kernel = kernel_type(c=scaled_mm_linear_kernel_config, + w_q_param_name="weight", + w_s_param_name="weight_scale", + i_s_param_name="input_scale", + i_zp_param_name="input_zero_point", + azp_adj_param_name="azp_adj") + + # Checkpoints are serialized in compressed-tensors format, which is + # different from the format the kernel may want. Handle repacking here. + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + + # return self.kernel.apply_weights(layer, x, bias) + + return apply_int8_linear(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + input_zero_point=layer.input_zero_point, + azp_adj=layer.azp_adj, + bias=bias, + w8a8_strategy=self.w8a8_strategy) + diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py new file mode 100644 index 0000000..7478760 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -0,0 +1,201 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Callable, Optional + +import torch +from compressed_tensors.quantization import ActivationOrdering + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( + MPLinearLayerConfig, choose_mp_linear_kernel) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_repeat_scales_on_all_ranks) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter) +# yapf: enable +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + +__all__ = ["CompressedTensorsWNA16"] +WNA16_SUPPORTED_TYPES_MAP = { + 4: scalar_types.uint4b8, + 8: scalar_types.uint8b128 +} +WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8} +WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) + + +class CompressedTensorsWNA16(CompressedTensorsScheme): + _kernel_backends_being_used: set[str] = set() + + def __init__(self, + strategy: str, + num_bits: int, + group_size: Optional[int] = None, + symmetric: Optional[bool] = True, + actorder: Optional[ActivationOrdering] = None): + + self.pack_factor = 32 // num_bits + self.strategy = strategy + self.symmetric = symmetric + self.group_size = -1 if group_size is None else group_size + self.has_g_idx = actorder == ActivationOrdering.GROUP + + if self.group_size == -1 and self.strategy != "channel": + raise ValueError("Marlin kernels require group quantization or " + "channelwise quantization, but found no group " + "size and strategy is not channelwise.") + + if num_bits not in WNA16_SUPPORTED_TYPES_MAP: + raise ValueError( + f"Unsupported num_bits = {num_bits}. " + f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}") + + self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits] + if not self.symmetric else + WNA16_SUPPORTED_TYPES_MAP[num_bits]) + + @classmethod + def get_min_capability(cls) -> int: + # ampere and up + return 80 + + def create_weights(self, layer: torch.nn.Module, output_size: int, + input_size: int, output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + output_size_per_partition = sum(output_partition_sizes) + + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_type, + act_type=params_dtype, + group_size=self.group_size, + zero_points=not self.symmetric, + has_g_idx=self.has_g_idx + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for CompressedTensorsWNA16", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # If group_size is -1, we are in channelwise case. + group_size = self.group_size if self.group_size != -1 else input_size + row_parallel = (input_size != input_size_per_partition) + partition_scales = not marlin_repeat_scales_on_all_ranks( + self.has_g_idx, self.group_size, row_parallel) + + scales_and_zp_size = input_size // group_size + + if partition_scales: + assert input_size_per_partition % group_size == 0 + scales_and_zp_size = input_size_per_partition // group_size + + weight = PackedvLLMParameter(input_dim=1, + output_dim=0, + weight_loader=weight_loader, + packed_factor=self.pack_factor, + packed_dim=1, + data=torch.empty( + output_size_per_partition, + input_size_per_partition // + self.pack_factor, + dtype=torch.int32, + )) + + weight_scale_args = { + "weight_loader": + weight_loader, + "data": + torch.empty( + output_size_per_partition, + scales_and_zp_size, + dtype=params_dtype, + ) + } + + zeros_args = { + "weight_loader": + weight_loader, + "data": + torch.zeros( + output_size_per_partition // self.pack_factor, + scales_and_zp_size, + dtype=torch.int32, + ) + } + + if not partition_scales: + weight_scale = ChannelQuantScaleParameter(output_dim=0, + **weight_scale_args) + + if not self.symmetric: + qzeros = PackedColumnParameter(output_dim=0, + packed_dim=0, + packed_factor=self.pack_factor, + **zeros_args) + else: + weight_scale = GroupQuantScaleParameter(output_dim=0, + input_dim=1, + **weight_scale_args) + if not self.symmetric: + qzeros = PackedvLLMParameter(input_dim=1, + output_dim=0, + packed_dim=0, + packed_factor=self.pack_factor, + **zeros_args) + + # A 2D array defining the original shape of the weights + # before packing + weight_shape = BasevLLMParameter(data=torch.empty(2, + dtype=torch.int64), + weight_loader=weight_loader) + + layer.register_parameter("weight_packed", weight) + layer.register_parameter("weight_scale", weight_scale) + layer.register_parameter("weight_shape", weight_shape) + + if not self.symmetric: + layer.register_parameter("weight_zero_point", qzeros) + + # group index (for activation reordering) + if self.has_g_idx: + weight_g_idx = RowvLLMParameter(data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight_g_idx", weight_g_idx) + + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + w_zp_param_name="weight_zero_point", + w_gidx_param_name="weight_g_idx") + + # Checkpoints are serialized in compressed-tensors format, which is + # different from the format the kernel may want. Handle repacking here. + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py new file mode 100644 index 0000000..d926b4c --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py @@ -0,0 +1,206 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm.triton_utils import tl, triton + + +def is_weak_contiguous(x: torch.Tensor): + strides = x.stride() + sizes = x.shape + is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0])) + is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1])) + return is_transpose or is_not_transpose + + +@triton.jit +def scaled_mm_kernel(a_ptr, b_ptr, scale_a_ptr, scale_b_ptr, c_ptr, bias_ptr, + M, N, K, stride_am, stride_ak, stride_bk, stride_bn, + stride_cm, stride_cn, ACCUMULATOR_DTYPE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_SCALE_A: tl.constexpr, + BLOCK_SIZE_SCALE_B: tl.constexpr): + pid = tl.program_id(axis=0) + + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + accumulator_dtype = ACCUMULATOR_DTYPE + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), + dtype=accumulator_dtype) + + # NOTE: Some tensor inputs are so large, they will cause int32 overflow + # so it is necessary to use tl.int64 for all the offsets, else SEGV will + # eventually occur. + + # Offsets and masks. + offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + masks_am = offsets_am < M + + offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + masks_bn = offsets_bn < N + + offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + offsets_a = (stride_am * offsets_am[:, None] + + stride_ak * offsets_k[None, :]) + offsets_b = (stride_bk * offsets_k[:, None] + + stride_bn * offsets_bn[None, :]) + + # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create + # appropriate offsets and masks for each case. Same goes for + # BLOCK_SIZE_SCALE_B. + offsets_scale_am = (tl.arange(0, BLOCK_SIZE_SCALE_A) + + (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M) + masks_scale_am = offsets_scale_am < M + + offsets_scale_bn = (tl.arange(0, BLOCK_SIZE_SCALE_B) + + (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N) + masks_scale_bn = offsets_scale_bn < N + + a_ptrs = a_ptr + offsets_a + b_ptrs = b_ptr + offsets_b + + scale_a_ptrs = scale_a_ptr + offsets_scale_am + scale_b_ptrs = scale_b_ptr + offsets_scale_bn + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + masks_k = offsets_k < K + masks_a = masks_am[:, None] & masks_k[None, :] + a = tl.load(a_ptrs, mask=masks_a) + + masks_b = masks_k[:, None] & masks_bn[None, :] + b = tl.load(b_ptrs, mask=masks_b) + + # Accumulate results. + accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype) + + offsets_k += BLOCK_SIZE_K + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # Apply scale at end. + masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None] + scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a) + # Need to broadcast to the appropriate size, if scale_a is already + # (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes + # for scale_b below. + scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1)) + accumulator = scale_a * accumulator.to(tl.float32) + + masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :] + scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b) + scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1)) + accumulator = scale_b.T * accumulator.to(tl.float32) + + # Convert to output format. + c = accumulator.to(c_ptr.type.element_ty) + + # Add bias, it's already in output format, so add it after conversion. + if bias_ptr: + offsets_bias = offsets_bn + bias_ptrs = bias_ptr + offsets_bias + bias_mask = offsets_bias < N + bias = tl.load(bias_ptrs, bias_mask) + c += bias + + # Save output + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + offs_cm = offs_cm.to(tl.int64) + offs_cn = offs_cn.to(tl.int64) + c_ptrs = (c_ptr + stride_cm * offs_cm[:, None] + + stride_cn * offs_cn[None, :]) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + + tl.store(c_ptrs, c, mask=c_mask) + + +# input - [M, K] +# weight - [K, N] +def triton_scaled_mm(input: torch.Tensor, + weight: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: type[torch.dtype], + bias: Optional[torch.Tensor] = None, + block_size_m: int = 32, + block_size_n: int = 32, + block_size_k: int = 32, + use_heuristic=True) -> torch.Tensor: + M, K = input.shape + N = weight.shape[1] + + assert N > 0 and K > 0 and M > 0 + assert weight.shape[0] == K + assert input.dtype == weight.dtype + + scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a + scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b + + assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point() + assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1 + or scale_a.shape[0] == M) + assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1 + or scale_b.shape[0] == N) + assert out_dtype.is_floating_point + assert bias is None or bias.is_floating_point() + assert is_weak_contiguous(input) + assert is_weak_contiguous(weight) + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( + N, META['BLOCK_SIZE_N']), ) + + result = torch.empty((M, N), dtype=out_dtype, device=input.device) + + has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1 + + if use_heuristic: + is_small_N = N < 8192 + next_power_of_2_M = max(32, triton.next_power_of_2(M)) + if next_power_of_2_M <= 32: + tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256) + elif next_power_of_2_M <= 64: + tile_shape = (64, 64, 256) + elif next_power_of_2_M <= 128: + tile_shape = (64, 128, 128) + else: + tile_shape = (128, 128, 128) + + block_size_m, block_size_n, block_size_k = tile_shape + + block_size_sa = 1 if has_scalar(scale_a) else block_size_m + block_size_sb = 1 if has_scalar(scale_b) else block_size_n + + accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32 + + # A = input, B = weight, C = result + # A = M x K, B = K x N, C = M x N + scaled_mm_kernel[grid](input, + weight, + scale_a, + scale_b, + result, + bias, + M, + N, + K, + input.stride(0), + input.stride(1), + weight.stride(0), + weight.stride(1), + result.stride(0), + result.stride(1), + accumulator_dtype, + BLOCK_SIZE_M=block_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + BLOCK_SIZE_SCALE_A=block_size_sa, + BLOCK_SIZE_SCALE_B=block_size_sb) + + return result.to(out_dtype) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py new file mode 100644 index 0000000..099d861 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -0,0 +1,216 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable, Mapping +from types import MappingProxyType +from typing import Optional + +import regex as re +from compressed_tensors import CompressionFormat +from torch.nn import Module + + +def is_activation_quantization_format(format: str) -> bool: + _ACTIVATION_QUANTIZATION_FORMATS = [ + CompressionFormat.naive_quantized.value, + CompressionFormat.int_quantized.value, + CompressionFormat.float_quantized.value, + CompressionFormat.nvfp4_pack_quantized.value + ] + return format in _ACTIVATION_QUANTIZATION_FORMATS + + +def should_ignore_layer( + layer_name: Optional[str], + ignore: Iterable[str] = tuple(), + fused_mapping: Mapping[str, list[str]] = MappingProxyType({}) +) -> bool: + if layer_name is None: + return False + + # layer_name = model.layers.0.self_attn.qkv_proj + # proj_name = qkv_proj + proj_name = layer_name.split(".")[-1] + + # Fused layers like gate_up_proj or qkv_proj will not be fused + # in the safetensors checkpoint. So, we convert the name + # from the fused version to unfused + check to make sure that + # each shard of the fused layer has the same scheme. + if proj_name in fused_mapping and layer_name not in ignore: + shard_proj_names = fused_mapping[proj_name] + + # Convert fused_name --> [shard_names] + shard_names = [ + layer_name.replace(proj_name, shard_proj_name) + for shard_proj_name in shard_proj_names + ] + + # Layer should be ignored if shards are ignored. + should_ignore_layer = None + for shard_name in shard_names: + should_ignore_shard = check_equal_or_regex_match( + layer_name=shard_name, targets=ignore) + + # If shard_idx=0, set layer ignore to match shard. + if should_ignore_layer is None: + should_ignore_layer = should_ignore_shard + + # If shard_idx=1+ confirm scheme matches prior shards. + elif should_ignore_shard != should_ignore_layer: + raise ValueError(f"Found a different quantization schemes for " + f"{shard_proj_names} in {layer_name}. vLLM " + "requires all to use the same scheme.") + + # Unfused layers like down_proj and o_proj will match + # the safetensors checkpoint already. + else: + should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name, + targets=ignore) + + assert should_ignore_layer is not None + return should_ignore_layer + + +def check_equal_or_regex_match(layer_name: str, + targets: Iterable[str]) -> bool: + """ + Checks whether a layer_name is exactly equal or a regex match for + if target starts with 're:' to any target in list. + """ + for target in targets: + if _is_equal_or_regex_match(layer_name, target): + return True + return False + + +def find_matched_target( + layer_name: Optional[str], + module: Module, + targets: Iterable[str], + fused_mapping: Mapping[str, list[str]] = MappingProxyType({}) +) -> str: + """ + Helper function to look up which "target" in the compressed-tensors + config that a layer corresponds to. + + Recall that a compressed-tensors configs has a concept of + config_groups, where each layer can be quantized with with a different + scheme. + + targets in each config_group will be a list of either layer names + (or regexes corresponding to layer names) or names of torch Modules. + + First, we try to match the layer_name with a target + Second, we try to match the module's name with a target + Third, we try to map the layer_name to a list of fused module names. + *All* component module names must match in order for a match to be + successful. A successful match returns the first component target + + :param layer_name: layer name + :param module: torch.nn.Module + :param targets: list of targets to match the layer against + :param fused_mapping: map from fused layer names to its components + :param fused_strategy: either "all" or "any". If using "all", fused + layers match if "all" of its components match + """ + + if layer_name is None: + layer_name = "" + + matched_target = ( + _find_first_match(layer_name, targets) + or _find_first_match(module.__class__.__name__, targets, True) + or _match_fused_layer(layer_name, targets, fused_mapping)) + + if matched_target is None: + raise ValueError( + f"Unable to find matching target for {layer_name} in the " + "compressed-tensors config.") + + return matched_target + + +def _find_first_match(value: str, + targets: Iterable[str], + check_contains: bool = False) -> Optional[str]: + """ + Returns first element of target that matches value either + exactly or as a regex after 're:'. If check_contains is set to True, + additionally checks if the target string is contained within the value. + + :param value: string to compare the list of targets against + :param targets: list of targets to match the layer against + :param check_contains: whether or not to do a substring match + """ + + for target in targets: + if _is_equal_or_regex_match(value, + target, + check_contains=check_contains): + return target + return None + + +def _is_equal_or_regex_match(value: str, + target: str, + check_contains: bool = False) -> bool: + """ + Checks whether a value is exactly equal or a regex match for target + if target starts with 're:'. If check_contains is set to True, + additionally checks if the target string is contained within the value. + """ + + if target.startswith("re:"): + pattern = target[3:] + if re.match(pattern, value): + return True + elif check_contains: + if target.lower() in value.lower(): + return True + elif target == value: + return True + return False + + +def _match_fused_layer( + layer_name: str, target_layers: Iterable[str], + fused_mapping: Mapping[str, list[str]]) -> Optional[str]: + """ + Match a fused layer name to its corresponding individual layer in + target_layers. Returns first value in fused_mapping which matches targets + + Implements an "all" matching strategy where a fused layer matches iff + "all" of its components match + + :param layer_name: layer name + :param target_layers: list of targets to match the layer against + :param fused_mapping: map from fused layer names to its components + + Examples: + layer_name = "model.layers.0.self_attn.qkv_proj" + target_layers = ["model.layers.0.self_attn.q_proj", + "model.layers.0.self_attn.k_proj", + "model.layers.0.self_attn.v_proj"] + """ + # find layer_name in mapping + fused = next((key for key in fused_mapping if layer_name.endswith(key)), + None) + if fused is None: + return None + + # expand path of unfused components + unfused_paths = [ + layer_name.replace(fused, unfused) for unfused in fused_mapping[fused] + ] + + # for each unfused component, find a match in targets + unfused_matches: list[Optional[str]] = [] + for unfused in unfused_paths: + for target in target_layers: + if _is_equal_or_regex_match(unfused, target): + unfused_matches.append(target) + break + else: + unfused_matches.append(None) + + return unfused_matches[0] if all(unfused_matches) else None diff --git a/vllm/model_executor/layers/quantization/configs/awq/AWQ_1536_7168_BW200.json b/vllm/model_executor/layers/quantization/configs/awq/AWQ_1536_7168_BW200.json new file mode 100644 index 0000000..4d58a0d --- /dev/null +++ b/vllm/model_executor/layers/quantization/configs/awq/AWQ_1536_7168_BW200.json @@ -0,0 +1,244 @@ +{ + "1536_7168": { + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "3": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "5": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "6": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "7": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "9": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "10": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "11": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "12": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "13": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "14": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "15": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 2, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 2, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + } + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/configs/awq/AWQ_1536_7168_K100_AI.json b/vllm/model_executor/layers/quantization/configs/awq/AWQ_1536_7168_K100_AI.json new file mode 100644 index 0000000..a980fb8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/configs/awq/AWQ_1536_7168_K100_AI.json @@ -0,0 +1,244 @@ +{ + "1536_7168": { + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "3": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "5": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "6": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "7": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "9": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "10": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "11": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "12": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "13": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "14": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "15": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 2, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + } + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/configs/awq/AWQ_3072_1536_BW200.json b/vllm/model_executor/layers/quantization/configs/awq/AWQ_3072_1536_BW200.json new file mode 100644 index 0000000..18edaa3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/configs/awq/AWQ_3072_1536_BW200.json @@ -0,0 +1,244 @@ +{ + "3072_1536": { + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "3": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "5": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "6": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "7": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "9": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "10": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "11": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "12": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "13": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "14": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "15": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 2, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + } + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/configs/awq/AWQ_3072_1536_K100_AI.json b/vllm/model_executor/layers/quantization/configs/awq/AWQ_3072_1536_K100_AI.json new file mode 100644 index 0000000..60559d0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/configs/awq/AWQ_3072_1536_K100_AI.json @@ -0,0 +1,244 @@ +{ + "3072_1536": { + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "3": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "5": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "6": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "7": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "9": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "10": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "11": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "12": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "13": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "14": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "15": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + } + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/configs/awq/AWQ_4096_512_BW200.json b/vllm/model_executor/layers/quantization/configs/awq/AWQ_4096_512_BW200.json new file mode 100644 index 0000000..049b682 --- /dev/null +++ b/vllm/model_executor/layers/quantization/configs/awq/AWQ_4096_512_BW200.json @@ -0,0 +1,244 @@ +{ + "4096_512": { + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "3": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "5": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "6": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "7": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "9": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "10": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "11": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "12": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "13": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "14": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "15": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + } + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/configs/awq/AWQ_4096_512_K100_AI.json b/vllm/model_executor/layers/quantization/configs/awq/AWQ_4096_512_K100_AI.json new file mode 100644 index 0000000..42c6b08 --- /dev/null +++ b/vllm/model_executor/layers/quantization/configs/awq/AWQ_4096_512_K100_AI.json @@ -0,0 +1,244 @@ +{ + "4096_512": { + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "3": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "5": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "6": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "7": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "9": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "10": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "11": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "12": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "13": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "14": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "15": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 2, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + } + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/configs/awq/AWQ_4608_7168_BW200.json b/vllm/model_executor/layers/quantization/configs/awq/AWQ_4608_7168_BW200.json new file mode 100644 index 0000000..f6f771f --- /dev/null +++ b/vllm/model_executor/layers/quantization/configs/awq/AWQ_4608_7168_BW200.json @@ -0,0 +1,244 @@ +{ + "4608_7168": { + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 1 + }, + "3": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 1 + }, + "5": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 1 + }, + "6": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 1 + }, + "7": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 1 + }, + "9": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 1 + }, + "10": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 1 + }, + "11": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 1 + }, + "12": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 1 + }, + "13": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 1 + }, + "14": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 1 + }, + "15": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 2, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 2, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + } + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/configs/awq/AWQ_4608_7168_K100_AI.json b/vllm/model_executor/layers/quantization/configs/awq/AWQ_4608_7168_K100_AI.json new file mode 100644 index 0000000..a9177cb --- /dev/null +++ b/vllm/model_executor/layers/quantization/configs/awq/AWQ_4608_7168_K100_AI.json @@ -0,0 +1,244 @@ +{ + "4608_7168": { + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "3": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "5": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "6": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "7": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "9": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "10": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "11": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "12": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "13": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "14": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "15": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 2, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + } + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/configs/awq/AWQ_512_7168_BW200.json b/vllm/model_executor/layers/quantization/configs/awq/AWQ_512_7168_BW200.json new file mode 100644 index 0000000..d6777ce --- /dev/null +++ b/vllm/model_executor/layers/quantization/configs/awq/AWQ_512_7168_BW200.json @@ -0,0 +1,244 @@ +{ + "512_7168": { + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "3": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "5": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "6": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "7": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "9": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "10": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "11": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "12": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "13": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "14": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "15": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 2, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + } + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/configs/awq/AWQ_512_7168_K100_AI.json b/vllm/model_executor/layers/quantization/configs/awq/AWQ_512_7168_K100_AI.json new file mode 100644 index 0000000..889d994 --- /dev/null +++ b/vllm/model_executor/layers/quantization/configs/awq/AWQ_512_7168_K100_AI.json @@ -0,0 +1,244 @@ +{ + "512_7168": { + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "3": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "5": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "6": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "7": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "9": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "10": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "11": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "12": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "13": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "14": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "15": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 2, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 2, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + } + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/configs/awq/AWQ_576_7168_BW200.json b/vllm/model_executor/layers/quantization/configs/awq/AWQ_576_7168_BW200.json new file mode 100644 index 0000000..6eea894 --- /dev/null +++ b/vllm/model_executor/layers/quantization/configs/awq/AWQ_576_7168_BW200.json @@ -0,0 +1,244 @@ +{ + "576_7168": { + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "3": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "5": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "6": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "7": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "9": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "10": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "11": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "12": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "13": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "14": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "15": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 2, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + } + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/configs/awq/AWQ_576_7168_K100_AI.json b/vllm/model_executor/layers/quantization/configs/awq/AWQ_576_7168_K100_AI.json new file mode 100644 index 0000000..e0c6cb9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/configs/awq/AWQ_576_7168_K100_AI.json @@ -0,0 +1,244 @@ +{ + "576_7168": { + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "3": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "5": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "6": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "7": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "9": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "10": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "11": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "12": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "13": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "14": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "15": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 2, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + } + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/configs/awq/AWQ_7168_2048_BW200.json b/vllm/model_executor/layers/quantization/configs/awq/AWQ_7168_2048_BW200.json new file mode 100644 index 0000000..83f4404 --- /dev/null +++ b/vllm/model_executor/layers/quantization/configs/awq/AWQ_7168_2048_BW200.json @@ -0,0 +1,244 @@ +{ + "7168_2048": { + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "3": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "5": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "6": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "7": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "9": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "10": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "11": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "12": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "13": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "14": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "15": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 2, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + } + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/configs/awq/AWQ_7168_2048_K100_AI.json b/vllm/model_executor/layers/quantization/configs/awq/AWQ_7168_2048_K100_AI.json new file mode 100644 index 0000000..f7d9ddd --- /dev/null +++ b/vllm/model_executor/layers/quantization/configs/awq/AWQ_7168_2048_K100_AI.json @@ -0,0 +1,244 @@ +{ + "7168_2048": { + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "3": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "5": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "6": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "7": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "9": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "10": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "11": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "12": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "13": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "14": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "15": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + } + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/configs/awq/AWQ_7168_2304_BW200.json b/vllm/model_executor/layers/quantization/configs/awq/AWQ_7168_2304_BW200.json new file mode 100644 index 0000000..c8e31b7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/configs/awq/AWQ_7168_2304_BW200.json @@ -0,0 +1,244 @@ +{ + "7168_2304": { + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "3": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "5": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "6": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "7": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "9": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "10": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "11": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "12": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "13": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "14": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "15": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 2, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + } + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/configs/awq/AWQ_7168_2304_K100_AI.json b/vllm/model_executor/layers/quantization/configs/awq/AWQ_7168_2304_K100_AI.json new file mode 100644 index 0000000..91c7fae --- /dev/null +++ b/vllm/model_executor/layers/quantization/configs/awq/AWQ_7168_2304_K100_AI.json @@ -0,0 +1,244 @@ +{ + "7168_2304": { + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "3": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "5": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "6": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "7": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "9": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "10": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "11": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "12": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "13": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "14": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "15": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 8, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 4, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 2, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 8, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + } + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/configs/awq/AWQ_7168_256_BW200.json b/vllm/model_executor/layers/quantization/configs/awq/AWQ_7168_256_BW200.json new file mode 100644 index 0000000..ea52f10 --- /dev/null +++ b/vllm/model_executor/layers/quantization/configs/awq/AWQ_7168_256_BW200.json @@ -0,0 +1,244 @@ +{ + "7168_256": { + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "3": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "5": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "6": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "7": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "9": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "10": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "11": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "12": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "13": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "14": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "15": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + } + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/configs/awq/AWQ_7168_256_K100_AI.json b/vllm/model_executor/layers/quantization/configs/awq/AWQ_7168_256_K100_AI.json new file mode 100644 index 0000000..01039f5 --- /dev/null +++ b/vllm/model_executor/layers/quantization/configs/awq/AWQ_7168_256_K100_AI.json @@ -0,0 +1,244 @@ +{ + "7168_256": { + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "3": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "5": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "6": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "7": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "9": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "10": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "11": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "12": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "13": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "14": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "15": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 0, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "SPLIT_K": 1, + "num_stages": 1, + "num_warps": 4, + "num_ldmatrixes": 1 + } + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/deepgemm.py b/vllm/model_executor/layers/quantization/deepgemm.py new file mode 100644 index 0000000..5903976 --- /dev/null +++ b/vllm/model_executor/layers/quantization/deepgemm.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import logging + +import torch + +from vllm.platforms import current_platform +from vllm.triton_utils import triton +from vllm.utils import direct_register_custom_op, has_deep_gemm + +if has_deep_gemm(): + import deep_gemm + +logger = logging.getLogger(__name__) + + +def prepare_block_fp8_matmul_inputs( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype = torch.float16, +) -> tuple[int, int, int, torch.Tensor]: + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] + assert A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + + M = A.numel() // A.shape[-1] + + assert B.ndim == 2 + assert B.is_contiguous() + assert Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] + + C_shape = A.shape[:-1] + (N, ) + C = A.new_empty(C_shape, dtype=output_dtype) + + return M, N, K, C + + +def w8a8_block_fp8_matmul_deepgemm( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, + output_dtype) + # Deepgemm only supports output tensor type as bfloat16 + assert C.dtype == torch.bfloat16 + deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C) + return C + + +def w8a8_block_fp8_matmul_deepgemm_fake( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, + output_dtype) + return C + + +direct_register_custom_op( + op_name="w8a8_block_fp8_matmul_deepgemm", + op_func=w8a8_block_fp8_matmul_deepgemm, + mutates_args=[], + fake_impl=w8a8_block_fp8_matmul_deepgemm_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py new file mode 100644 index 0000000..8030be5 --- /dev/null +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -0,0 +1,195 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + + +class DeepSpeedFPConfig(QuantizationConfig): + """Config for DeepSpeed FP quantizer. It supports fp6 and fp8. + + Args: + weight_bits: the target quantization bits, 6 or 8. + group_size: group size for quantizaiton, default to 128. + """ + + def __init__( + self, + weight_bits: int = 8, + group_size: int = 512, + ) -> None: + super().__init__() + self.weight_bits = weight_bits + self.group_size = group_size + self.valid_types = [torch.bfloat16, torch.float16] + + if self.weight_bits not in (6, 8): + raise ValueError( + "Currently, only 6-bit or 8-bit weight quantization are " + f"supported for DeepSpeed FP quantizaiton, but got " + f"{self.weight_bits} bits.") + + def __repr__(self) -> str: + return (f"DeepSpeedFPConfig(weight_bits={self.weight_bits}), " + f"group_size={self.group_size}") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "deepspeedfp" + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "DeepSpeedFPConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits=weight_bits, group_size=group_size) + + def get_linear_method(self) -> "DeepSpeedFPLinearMethod": + return DeepSpeedFPLinearMethod(self) + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 60 + + @staticmethod + def get_config_filenames() -> list[str]: + return [ + "quant_config.json", + "quantize_config.json", + ] + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["DeepSpeedFPLinearMethod"]: + if isinstance(layer, LinearBase): + return DeepSpeedFPLinearMethod(self) + return None + + +class DeepSpeedFPLinearMethod(LinearMethodBase): + """Linear method for DeepSpeedFP quantizer. + + Args: + quant_config: the DeepSpeedFP quantization config. + """ + + def __init__(self, quant_config: DeepSpeedFPConfig): + self.quant_config = quant_config + self.weight = None + + def create_weights(self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + weight_loader=None, + **extra_weight_attrs): + del output_size + del input_size + output_size_per_partition = sum(output_partition_sizes) + weight = DeepSpeedFPParameter( + torch.Size((output_size_per_partition, input_size_per_partition)), + params_dtype=params_dtype, + quant_config=self.quant_config, + ) + set_weight_attrs(weight, { + "input_dim": 1, + "output_dim": 0, + }) + layer.register_parameter("weight", weight) + + def quant_weight_loader(param, loaded_weight, *args, **kwargs): + # Calls the original weight loader (if any), quantizes the result, + # and then loads the quantized parameter. + if weight_loader is not None: + orig_param_data = param.data + param.data = param.ds_dequantize() + weight_loader(param, loaded_weight, *args, **kwargs) + param.data, loaded_weight = orig_param_data, param.data + param.ds_quantize_(loaded_weight.cuda()) + + extra_weight_attrs["weight_loader"] = quant_weight_loader + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + weight = layer.weight + y = weight.ds_dequantize() + return F.linear(x, y, bias) + + +class DeepSpeedFPParameter(nn.Parameter): + """ + DeepSpeedFP quantized parameter class that implements fp8/fp6 + quantization deepspeed. Weights are stored in quantized form on + GPUs, and can be dequantized on-the-fly when needed by the model. + """ + + def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype, + quant_config: DeepSpeedFPConfig): + try: + import deepspeed + if deepspeed.__version__ < "0.14.2": + raise ImportError("deepspeed version is wrong. Please " + "install deepspeed>=0.14.2.") + from deepspeed.ops.fp_quantizer import FP_Quantize + except ImportError as err: + raise ImportError("Please install deepspeed>=0.14.2 via " + "`pip install deepspeed>=0.14.2` to use " + "deepspeedfp quantizer.") from err + data = torch.empty(( + orig_shape.numel() // quant_config.group_size, + quant_config.group_size * quant_config.weight_bits // 8 + 4, + ), + dtype=torch.int8) + self = torch.Tensor._make_subclass(cls, data, data.requires_grad) + self.orig_shape = orig_shape + self.quant_config = quant_config + self.fp_quantizer = FP_Quantize(group_size=quant_config.group_size) + self.fp_quantizer.orig_shape = orig_shape + self.fp_quantizer.orig_dtype = params_dtype + return self + + def ds_quantize_(self, tensor: torch.Tensor): + assert tensor.device.type == "cuda" and tensor.dtype != torch.int8 + return self.data.copy_( + self.fp_quantizer.quantize( + tensor.data, + q_bits=self.quant_config.weight_bits, + )) + + def ds_dequantize(self, fp_out=None) -> torch.Tensor: + """ + Return a tensor containing the dequantized weights of this parameter. + """ + assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 + return self.fp_quantizer.dequantize( + self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits) + + def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor: + """ + Return a tensor where only the weights at `indices` are dequantized + (to save HBM -> SRAM bandwidth). + """ + assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 + return self.fp_quantizer.selective_dequantize( + self.data, + indices, + fp_out=fp_out, + q_bits=self.quant_config.weight_bits) diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py new file mode 100644 index 0000000..47eca80 --- /dev/null +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -0,0 +1,204 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Callable, Optional + +import torch + +from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group +from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.utils import set_weight_attrs + + +class ExpertsInt8Config(QuantizationConfig): + """Config class for Int8 experts quantization.""" + + def __init__(self) -> None: + super().__init__() + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "experts_int8" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "ExpertsInt8Config": + return cls() + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + return UnquantizedLinearMethod() + elif isinstance(layer, FusedMoE): + return ExpertsInt8MoEMethod(self) + return None + + +class ExpertsInt8MoEMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: ExpertsInt8Config): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + int8_dtype = torch.int8 + + assert 'weight_loader' in extra_weight_attrs + weight_loader = extra_weight_attrs['weight_loader'] + wrapped_weight_loader = ExpertsInt8MoEMethod.quantizing_weight_loader( + layer, weight_loader) + extra_weight_attrs['weight_loader'] = wrapped_weight_loader + + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=int8_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=int8_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w13_scale = torch.nn.Parameter(torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_scale", w13_scale) + + w2_scale = torch.nn.Parameter(torch.zeros(num_experts, + hidden_size, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_scale", w2_scale) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `ExpertsInt8MoEMethod` yet.") + + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + use_int8_w8a16=True, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale) + + @staticmethod + def quantizing_weight_loader(layer, weight_loader): + + def quantize_and_call_weight_loader(param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, shard_id: int, + expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + shard_size = layer.intermediate_size_per_partition + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + device = get_tp_group().device + loaded_weight = loaded_weight.to(device) + # w1, gate_proj case: Load into first shard of w13. + if shard_id == "w1": + scales = quantize_in_place_and_get_scales( + loaded_weight[shard, :]) + layer.w13_scale.data[expert_id, 0:shard_size].copy_(scales[:, + 0]) + # w3, up_proj case: Load into second shard of w13. + elif shard_id == "w3": + scales = quantize_in_place_and_get_scales( + loaded_weight[shard, :]) + layer.w13_scale.data[expert_id, shard_size:2 * + shard_size].copy_(scales[:, 0]) + # w2, down_proj case: Load into only shard of w2. + elif shard_id == "w2": + scales = quantize_in_place_and_get_scales(loaded_weight[:, + shard]) + layer.w2_scale.data[expert_id, :].copy_(scales[:, 0]) + else: + raise ValueError( + f"Shard id must be in [0,1,2] but got {shard_id}") + weight_loader(param, loaded_weight, weight_name, shard_id, + expert_id) + + return quantize_and_call_weight_loader + + +def quantize_in_place_and_get_scales(weight: torch.Tensor) -> torch.Tensor: + vmax = torch.iinfo(torch.int8).max + scales = (torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax) + + weight.div_(scales) + weight.round_() + weight.clamp_(-vmax, vmax) + + return scales diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py new file mode 100644 index 0000000..3e465ee --- /dev/null +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Optional + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + ModelWeightParameter) +from vllm.platforms import current_platform + +logger = init_logger(__name__) + + +class FBGEMMFp8Config(QuantizationConfig): + """Config class for FBGEMM Fp8.""" + + def __init__(self, ignore_list: list[str], input_scale_ub: float): + super().__init__() + self.ignore_list = ignore_list if ignore_list else [] + self.input_scale_ub = input_scale_ub + + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + self.use_marlin = not current_platform.has_device_capability(89) + self.fp8_linear = Fp8LinearOp() + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "fbgemm_fp8" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.float16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "FBGEMMFp8Config": + ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"]) + input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"]) + return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix=prefix, + ignored_layers=self.ignore_list, + fused_mapping=self.packed_modules_mapping): + return UnquantizedLinearMethod() + return FBGEMMFp8LinearMethod(self) + return None + + +class FBGEMMFp8LinearMethod(LinearMethodBase): + + def __init__(self, quant_config: FBGEMMFp8Config): + self.quant_config = quant_config + self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True) + self.out_dtype = torch.get_default_dtype() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + maybe_create_device_identity() + weight_loader = extra_weight_attrs.get("weight_loader") + del input_size, output_size + output_size_per_partition = sum(output_partition_sizes) + + layer.logical_widths = output_partition_sizes + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # WEIGHT + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + weight_scale = ChannelQuantScaleParameter(data=torch.empty( + (sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE UPPER BOUND + input_scale_ub = torch.nn.Parameter(torch.tensor( + (self.quant_config.input_scale_ub), dtype=torch.float32), + requires_grad=False) + layer.input_scale_ub = input_scale_ub + + def process_weights_after_loading(self, layer: Module) -> None: + # required by torch.compile + layer.weight_scale = Parameter(layer.weight_scale.data, + requires_grad=False) + layer.weight = Parameter(layer.weight.data, requires_grad=False) + + weight = layer.weight + + if current_platform.is_fp8_fnuz(): + weight, weight_scale, input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=layer.weight_scale, + input_scale=None) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + + layer.weight = Parameter(weight.t(), requires_grad=False) + if self.quant_config.use_marlin: + prepare_fp8_layer_for_marlin(layer) + # Activations not quantized for marlin. + del layer.input_scale_ub + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + if self.quant_config.use_marlin: + return apply_fp8_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias) + + return self.fp8_linear.apply(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, + input_scale=None, + input_scale_ub=layer.input_scale_ub, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py new file mode 100644 index 0000000..5a1a427 --- /dev/null +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -0,0 +1,950 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import functools +from typing import TYPE_CHECKING, Any, Callable, Optional + +import torch +import torch.nn.functional as F +from torch.nn import Module +from torch.nn.parameter import Parameter + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, + FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, + FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, + prepare_moe_fp8_layer_for_marlin) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported, + cutlass_fp8_supported, maybe_create_device_identity, + normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize, + requantize_with_max_scale) +from vllm.model_executor.parameter import (BlockQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types +from vllm.utils import has_deep_gemm + +if TYPE_CHECKING: + from vllm.model_executor.models.utils import WeightsMapper + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = init_logger(__name__) + + +def _is_col_major(x: torch.Tensor) -> bool: + assert x.dim() == 3 + b, m, n = x.shape + return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m + + +class Fp8Config(QuantizationConfig): + """Config class for FP8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + activation_scheme: str = "dynamic", + ignored_layers: Optional[list[str]] = None, + weight_block_size: Optional[list[int]] = None, + ) -> None: + super().__init__() + + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError( + f"Unsupported activation scheme {activation_scheme}") + self.activation_scheme = activation_scheme + self.ignored_layers = ignored_layers or [] + if weight_block_size is not None: + if not is_checkpoint_fp8_serialized: + raise ValueError( + "The block-wise quantization only supports fp8-serialized " + "checkpoint for now.") + if len(weight_block_size) != 2: + raise ValueError( + "The quantization block size of weight must have 2 " + f"dimensions, but got {len(weight_block_size)} dimensions") + if activation_scheme != "dynamic": + raise ValueError("The block-wise quantization only supports " + "dynamic activation scheme for now, but got " + f"{activation_scheme} activation scheme.") + self.weight_block_size = weight_block_size + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "fp8" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + if self.ignored_layers is not None: + self.ignored_layers = hf_to_vllm_mapper.apply_list( + self.ignored_layers) + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "Fp8Config": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = ("fp8" in quant_method) + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], + None) + return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers, + weight_block_size=weight_block_size) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping): + return UnquantizedLinearMethod() + return Fp8LinearMethod(self) + elif isinstance(layer, FusedMoE): + return Fp8MoEMethod(self) + elif isinstance(layer, Attention): + return Fp8KVCacheMethod(self) + return None + + def get_cache_scale(self, name: str) -> Optional[str]: + """ + Check whether the param name matches the format for k/v cache scales + in compressed-tensors. If this is the case, return its equivalent + param name expected by vLLM + + :param name: param name + :return: matching param name for KV cache scale in vLLM + """ + if name.endswith(".output_scale") and ".k_proj" in name: + return name.replace(".k_proj.output_scale", ".attn.k_scale") + if name.endswith(".output_scale") and ".v_proj" in name: + return name.replace(".v_proj.output_scale", ".attn.v_scale") + if name.endswith(".output_scale") and ".q_proj" in name: + return name.replace(".q_proj.output_scale", ".attn.q_scale") + if name.endswith("self_attn.prob_output_scale"): + return name.replace(".prob_output_scale", ".attn.prob_scale") + # If no matches, return None + return None + + +class Fp8LinearMethod(LinearMethodBase): + """Linear method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. Only support float8_e4m3fn data type due to the limitation of + torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() + self.out_dtype = torch.get_default_dtype() + + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + self.use_marlin = (not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN) + # Disable marlin for rocm + if current_platform.is_rocm(): + self.use_marlin = False + + # AITER is only supported on ROCm and only for FP8_FNUZ + # and at the moment are MI300 series + self.use_aiter_and_is_supported = (current_platform.is_rocm() + and envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_AITER_LINEAR + and current_platform.is_fp8_fnuz()) + + self.block_quant = self.quant_config.weight_block_size is not None + self.fp8_linear = Fp8LinearOp( + # Default to using per_token quantization if cutlass is supported + use_per_token_if_dynamic=cutlass_fp8_supported()) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + maybe_create_device_identity() + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + layer.weight_block_size = None + + if self.block_quant: + tp_size = get_tensor_model_parallel_world_size() + assert self.quant_config.weight_block_size is not None + layer.weight_block_size = self.quant_config.weight_block_size + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # Required by row parallel + if (tp_size > 1 + and input_size // input_size_per_partition == tp_size + and input_size_per_partition % block_k != 0): + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"weight quantization block_k = {block_k}.") + # Required by column parallel or enabling merged weights + if (tp_size > 1 and output_size // output_size_per_partition + == tp_size) or len(output_partition_sizes) > 1: + for output_partition_size in output_partition_sizes: + if output_partition_size % block_n != 0: + raise ValueError( + f"Weight output_partition_size = " + f"{output_partition_size} is not divisible by " + f"weight quantization block_n = {block_n}.") + + # WEIGHT + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized else + params_dtype) + + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=weight_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + # If checkpoint is serialized fp8, load them. + # Otherwise, wait until process_weights_after_loading. + if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALE + if not self.block_quant: + scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), + dtype=torch.float32), + weight_loader=weight_loader, + ) + scale[:] = torch.finfo(torch.float32).min + set_weight_attrs(scale, {"scale_type": "weight_scale"}) + layer.register_parameter("weight_scale", scale) + else: + assert self.quant_config.activation_scheme == "dynamic" + scale = BlockQuantScaleParameter( + data=torch.empty( + (output_size_per_partition + block_n - 1) // block_n, + (input_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + scale[:] = torch.finfo(torch.float32).min + set_weight_attrs(scale, {"scale_type": "weight_scale"}) + # The weight_scale_inv name is intentional for deepseekv3 + layer.register_parameter("weight_scale_inv", scale) + + # INPUT ACTIVATION SCALE + if self.quant_config.activation_scheme == "static": + scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + + scale[:] = torch.finfo(torch.float32).min + set_weight_attrs(scale, {"scale_type": "input_scale"}) + layer.register_parameter("input_scale", scale) + else: + layer.register_parameter("input_scale", None) + + def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: + # Pad the weight tensor. This is an optimization on ROCm platform, which + # can benefit from tensors located far enough from one another in memory + if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm() + and weight.stride(-1) == 1 + and (weight.stride(-2) * weight.element_size()) % 512 == 0): + num_pad = 256 // weight.element_size() + weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] + torch.cuda.empty_cache() + return weight + + def process_weights_after_loading(self, layer: Module) -> None: + size_k_first = True + # TODO(rob): refactor block quant into separate class. + if self.block_quant: + assert self.quant_config.activation_scheme == "dynamic" + size_k_first = False + if current_platform.is_fp8_fnuz(): + weight, weight_scale_inv, _ = \ + normalize_e4m3fn_to_e4m3fnuz( + weight=layer.weight, + weight_scale=layer.weight_scale_inv) + else: + weight = layer.weight.data + weight_scale_inv = layer.weight_scale_inv.data + + weight = self._maybe_pad_weight(weight) + + # Torch.compile cannot use Parameter subclasses. + layer.weight = Parameter(weight, requires_grad=False) + layer.weight_scale_inv = Parameter(weight_scale_inv, + requires_grad=False) + + # If checkpoint not serialized fp8, quantize the weights. + elif not self.quant_config.is_checkpoint_fp8_serialized: + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, + scale=None) + + # Update the layer with the new values. + layer.weight = Parameter(qweight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.input_scale = None + + # If checkpoint is fp8, handle that there are N scales for N + # shards in a fused module + else: + layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, + requires_grad=False) + if self.quant_config.activation_scheme == "static": + layer.input_scale = torch.nn.Parameter(layer.input_scale.data, + requires_grad=False) + + weight = layer.weight + weight_scale = layer.weight_scale + + # If using w8a8, torch._scaled_mm needs per tensor, so + # requantize the logical shards as a single weight. + if not self.use_marlin: + # Dequant -> Quant with max scale so we can run per tensor. + if current_platform.is_fp8_fnuz(): + weight, weight_scale, input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=weight_scale, + input_scale=layer.input_scale) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, + requires_grad=False) + + weight_scale, weight = requantize_with_max_scale( + weight=weight, + weight_scale=weight_scale, + logical_widths=layer.logical_widths, + ) + + weight = self._maybe_pad_weight(weight) + # Update layer with new values. + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + if self.quant_config.activation_scheme == "static": + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) + + if self.use_marlin: + prepare_fp8_layer_for_marlin(layer, size_k_first) + # Activations not quantized for marlin. + del layer.input_scale + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + if self.use_marlin: + return apply_fp8_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias) + + if self.block_quant: + assert self.quant_config.weight_block_size is not None + + return torch.ops.vllm.apply_w8a8_block_fp8_linear( + input=x, + weight=layer.weight, + block_size=self.quant_config.weight_block_size, + weight_scale=layer.weight_scale_inv, + input_scale=layer.input_scale, + bias=bias, + cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + use_aiter_and_is_supported=self.use_aiter_and_is_supported, + ) + + return self.fp8_linear.apply(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, + input_scale=layer.input_scale, + bias=bias) + + +class Fp8MoEMethod(FusedMoEMethodBase): + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Fp8Config): + + from vllm.model_executor.layers.fused_moe import fused_experts + self.quant_config = quant_config + self.block_quant = self.quant_config.weight_block_size is not None + + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + self.use_marlin = (not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN) + # Disable marlin for rocm + if current_platform.is_rocm(): + self.use_marlin = False + + # Check for DeepGemm support. + self.allow_deep_gemm = False + if envs.VLLM_USE_DEEP_GEMM: + if not has_deep_gemm(): + logger.warning_once("Failed to import DeepGemm kernels.") + elif not self.block_quant: + logger.warning_once("Model is not block quantized. Not using " + " DeepGemm kernels") + elif (current_platform.is_cuda() + and current_platform.has_device_capability(90)): + logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.") + self.allow_deep_gemm = True + else: + logger.warning_once( + "DeepGemm not supported on the current platform.") + + # Check for CutlassBlockScaledGroupedGemm support. + self.allow_cutlass_block_scaled_grouped_gemm = False + if not self.block_quant: + logger.warning_once("Model is not block quantized. Not using " + "CutlassBlockScaledGroupedGemm kernels") + elif (current_platform.is_cuda() + and current_platform.has_device_capability(100)): + logger.info_once( + "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod." + ) + self.allow_cutlass_block_scaled_grouped_gemm = True + else: + logger.warning_once( + "CutlassBlockScaledGroupedGemm not supported on the current " + "platform.") + + self.topk_indices_dtype = None + self.fused_experts = functools.partial( # type: ignore + fused_experts, + use_fp8_w8a8=True, + block_shape=self.quant_config.weight_block_size, + allow_deep_gemm=self.allow_deep_gemm, + allow_cutlass_block_scaled_grouped_gemm=( + self.allow_cutlass_block_scaled_grouped_gemm)) + + def create_weights(self, layer: Module, num_experts: int, hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + layer.intermediate_size_per_partition = intermediate_size_per_partition + layer.hidden_size = hidden_size + layer.num_experts = num_experts + layer.orig_dtype = params_dtype + layer.weight_block_size = None + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + if self.block_quant: + assert self.quant_config.weight_block_size is not None + layer.weight_block_size = self.quant_config.weight_block_size + tp_size = get_tensor_model_parallel_world_size() + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # NOTE: To ensure proper alignment of the block-wise quantization + # scales, the output_size of the weights for both the gate and up + # layers must be divisible by block_n. + # Required by column parallel or enabling merged weights + if intermediate_size_per_partition % block_n != 0: + raise ValueError( + f"The output_size of gate's and up's weight = " + f"{intermediate_size_per_partition} is not divisible by " + f"weight quantization block_n = {block_n}.") + if (tp_size > 1 + and intermediate_size_per_partition % block_k != 0): + # Required by row parallel + raise ValueError( + f"The input_size of down's weight = " + f"{intermediate_size_per_partition} is not divisible by " + f"weight quantization block_k = {block_k}.") + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + if not self.block_quant: + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, 2, dtype=torch.float32), + requires_grad=False) + w2_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + else: + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * ((intermediate_size_per_partition + block_n - 1) // + block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + (hidden_size + block_n - 1) // block_n, + (intermediate_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + assert self.quant_config.activation_scheme == "dynamic" + + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK. + value} if self.block_quant else + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.quant_config.activation_scheme == "static": + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8.") + + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + # Lazy import to avoid importing triton too early. + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_moe_enabled, shuffle_weights) + + self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + + # TODO (rob): refactor block quant into separate class. + if self.block_quant: + assert self.quant_config.activation_scheme == "dynamic" + if current_platform.is_fp8_fnuz(): + w13_weight, w13_weight_scale_inv, w13_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale_inv, + layer.w13_input_scale) + w2_weight, w2_weight_scale_inv, w2_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale_inv, + layer.w2_input_scale) + else: + w13_weight = layer.w13_weight.data + w13_weight_scale_inv = layer.w13_weight_scale_inv.data + w2_weight = layer.w2_weight + w2_weight_scale_inv = layer.w2_weight_scale_inv + + # torch.compile() cannot use Parameter subclasses. + layer.w13_weight = Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv, + requires_grad=False) + layer.w2_weight = Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv, + requires_grad=False) + if self.rocm_aiter_moe_enabled: + # reshaping weights is required for aiter moe kernel. + shuffled_w13, shuffled_w2 = shuffle_weights( + layer.w13_weight.data, layer.w2_weight.data) + + layer.w13_weight = torch.nn.Parameter(shuffled_w13, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, + requires_grad=False) + + # DeepGemm scales need to be transposed and aligned. We try to do + # it ahead of time for performance reasons. + if self.allow_deep_gemm: + # Lazy import to avoid CUDA initialization problems. + import deep_gemm as dg + if _is_col_major(layer.w13_weight_scale_inv): + layer.w13_weight_scale_inv = \ + dg.get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous() + if _is_col_major(layer.w2_weight_scale_inv): + layer.w2_weight_scale_inv = \ + dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous() + + # If checkpoint is fp16, quantize in place. + elif not self.quant_config.is_checkpoint_fp8_serialized: + fp8_dtype = current_platform.fp8_dtype() + w13_weight = torch.empty_like(layer.w13_weight.data, + dtype=fp8_dtype) + w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) + + # Re-initialize w13_scale because we directly quantize + # merged w13 weights and generate a single scaling factor. + layer.w13_weight_scale = torch.nn.Parameter(torch.ones( + layer.local_num_experts, + dtype=torch.float32, + device=w13_weight.device), + requires_grad=False) + for expert in range(layer.local_num_experts): + w13_weight[expert, :, :], layer.w13_weight_scale[ + expert] = ops.scaled_fp8_quant( + layer.w13_weight.data[expert, :, :]) + w2_weight[expert, :, :], layer.w2_weight_scale[ + expert] = ops.scaled_fp8_quant( + layer.w2_weight.data[expert, :, :]) + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) + if self.rocm_aiter_moe_enabled: + # reshaping weights is required for aiter moe kernel. + shuffled_w13, shuffled_w2 = shuffle_weights( + layer.w13_weight, layer.w2_weight) + + layer.w13_weight = torch.nn.Parameter(shuffled_w13, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, + requires_grad=False) + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. + else: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.quant_config.activation_scheme == "static": + if (layer.w13_input_scale is None + or layer.w2_input_scale is None): + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + if (not all_close_1d(layer.w13_input_scale) + or not all_close_1d(layer.w2_input_scale)): + logger.warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer.") + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False) + if current_platform.is_fp8_fnuz(): + # Normalize the weights and scales + w13_weight, w13_weight_scale, w13_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, + layer.w13_input_scale) + w2_weight, w2_weight_scale, w2_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, + layer.w2_input_scale) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + w13_weight_scale, requires_grad=False) + if w13_input_scale is not None: + layer.w13_input_scale = torch.nn.Parameter( + w13_input_scale, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, + requires_grad=False) + if w2_input_scale is not None: + layer.w2_input_scale = torch.nn.Parameter( + w2_input_scale, requires_grad=False) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.local_num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + + if self.rocm_aiter_moe_enabled: + shuffled_w13, shuffled_w2 = shuffle_weights( + layer.w13_weight, layer.w2_weight) + + layer.w13_weight = torch.nn.Parameter(shuffled_w13, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, + requires_grad=False) + + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + + if self.use_marlin: + prepare_moe_fp8_layer_for_marlin(layer, False) + # Activations not quantized for marlin. + del layer.w13_input_scale + del layer.w2_input_scale + + def select_gemm_impl( + self, + prepare_finalize: FusedMoEPrepareAndFinalize, + moe: FusedMoEConfig, + ) -> FusedMoEPermuteExpertsUnpermute: + from vllm.model_executor.layers.fused_moe import ( + BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts) + + assert not self.use_marlin and not self.rocm_aiter_moe_enabled, ( + "Marlin and ROCm AITER are not supported with all2all yet.") + + if (prepare_finalize.activation_format == + FusedMoEActivationFormat.BatchedExperts): + max_num_tokens_per_rank = ( + prepare_finalize.max_num_tokens_per_rank()) + assert max_num_tokens_per_rank is not None + logger.debug( + "BatchedTritonOrDeepGemmExperts(%s): " + "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s", + self.__class__.__name__, max_num_tokens_per_rank, + self.quant_config.weight_block_size, False) + return BatchedTritonOrDeepGemmExperts( + max_num_tokens=max_num_tokens_per_rank, + num_dispatchers=prepare_finalize.num_dispatchers(), + use_fp8_w8a8=True, + block_shape=self.quant_config.weight_block_size, + per_act_token_quant=False, + allow_deep_gemm=self.allow_deep_gemm, + ) + else: + logger.debug( + "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s", + self.__class__.__name__, self.quant_config.weight_block_size, + False) + return TritonOrDeepGemmExperts( + use_fp8_w8a8=True, + block_shape=self.quant_config.weight_block_size, + allow_deep_gemm=self.allow_deep_gemm, + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if enable_eplb: + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + assert isinstance(layer, FusedMoE) + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + if self.rocm_aiter_moe_enabled: + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 + rocm_aiter_fused_experts) + return rocm_aiter_fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + use_fp8_w8a8=True, + apply_router_weight_on_input=apply_router_weight_on_input, + w1_scale=(layer.w13_weight_scale_inv + if self.block_quant else layer.w13_weight_scale), + w2_scale=(layer.w2_weight_scale_inv + if self.block_quant else layer.w2_weight_scale), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=self.quant_config.weight_block_size, + expert_map=expert_map) + elif self.use_marlin: + assert activation == "silu", ( + f"{activation} not supported for Marlin MoE.") + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + quant_type_id=scalar_types.float8_e4m3fn.id, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map) + else: + return self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + w1_scale=(layer.w13_weight_scale_inv + if self.block_quant else layer.w13_weight_scale), + w2_scale=(layer.w2_weight_scale_inv + if self.block_quant else layer.w2_weight_scale), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) + + +class Fp8KVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from FP8 checkpoints. + """ + + def __init__(self, quant_config: Fp8Config): + super().__init__(quant_config) diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py new file mode 100644 index 0000000..86da04c --- /dev/null +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -0,0 +1,577 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Callable, Optional + +import gguf +import torch +from gguf import GGMLQuantizationType as WeightType +from torch.nn.parameter import Parameter, UninitializedParameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, + FusedMoEMethodBase) +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import direct_register_custom_op + +logger = init_logger(__name__) + + +class GGUFConfig(QuantizationConfig): + """Config class for GGUF.""" + + def __init__(self, ) -> None: + super().__init__() + + def __repr__(self) -> str: + return ("GGUFConfig()") + + def get_name(self) -> QuantizationMethods: + return "gguf" + + def get_supported_act_dtypes(self) -> list[torch.dtype]: + return [torch.half, torch.bfloat16, torch.float32] + + @classmethod + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] # no extra configs. + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "GGUFConfig": + return cls() + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + return GGUFLinearMethod(self) + elif isinstance(layer, VocabParallelEmbedding): + return GGUFEmbeddingMethod(self) + elif isinstance(layer, FusedMoE): + return GGUFMoEMethod(self) + return None + + +UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16} +STANDARD_QUANT_TYPES = { + WeightType.Q4_0, + WeightType.Q4_1, + WeightType.Q5_0, + WeightType.Q5_1, + WeightType.Q8_0, + WeightType.Q8_1, +} +KQUANT_TYPES = { + WeightType.Q2_K, + WeightType.Q3_K, + WeightType.Q4_K, + WeightType.Q5_K, + WeightType.Q6_K, +} +IMATRIX_QUANT_TYPES = { + WeightType.IQ1_M, + WeightType.IQ1_S, + WeightType.IQ2_XXS, + WeightType.IQ2_XS, + WeightType.IQ2_S, + WeightType.IQ3_XXS, + WeightType.IQ3_S, + WeightType.IQ4_XS, + WeightType.IQ4_NL, +} +# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization. +# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add +# MMQ kernel for I-Matrix quantization. +DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES +MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES +MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES + + +def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, + qweight_type: int) -> torch.Tensor: + if qweight_type in IMATRIX_QUANT_TYPES: + mmvq_safe = 8 if qweight.shape[0] > 5120 else 16 + else: + mmvq_safe = 2 if qweight.shape[0] > 5120 else 6 + # HACK: when doing chunked prefill we don't generate output tokens + # so input to logits generator is empty which causes invalid parameter + if x.shape[0] == 0: + return torch.empty(x.shape[0], + qweight.shape[0], + dtype=x.dtype, + device=x.device) + # there is no need to call any kernel for fp16/bf16 + if qweight_type in UNQUANTIZED_TYPES: + return x @ qweight.T + # enable MMVQ in contiguous batching with batch_size=1 + if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES: + y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0]) + # Use MMQ Kernel if it's available (standard + k-quants) + elif qweight_type in MMQ_QUANT_TYPES: + y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0]) + # If there is no available MMQ kernel, fallback to dequantize + elif qweight_type in DEQUANT_TYPES: + block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] + shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) + weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype) + y = x @ weight.T + else: + # Raise an error if the quantization type is not supported. + # Might be useful if llama.cpp adds a new quantization type. + # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type. + qweight_type = WeightType(qweight_type) + raise NotImplementedError( + f"Unsupported GGUF quantization type: {qweight_type}") + return y + + +def _fused_mul_mat_gguf_fake( + x: torch.Tensor, + qweight: torch.Tensor, + qweight_type: int, +) -> torch.Tensor: + return torch.empty(x.shape[0], + qweight.shape[0], + dtype=x.dtype, + device=x.device) + + +try: + direct_register_custom_op( + op_name="_fused_mul_mat_gguf", + op_func=_fused_mul_mat_gguf, + mutates_args=[], + fake_impl=_fused_mul_mat_gguf_fake, + ) + fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf + +except AttributeError as error: + raise error + + +def _fused_moe_gguf( + x: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + qweight_type: int, + qweight_type2: int, + activation: str, +) -> torch.Tensor: + + def act(x: torch.Tensor): + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + if activation == "silu": + torch.ops._C.silu_and_mul(out, x) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(out, x) + else: + raise ValueError(f"Unsupported activation: {activation}") + return out + + # lazy import to avoid triggering triton import in CPU backend + from vllm.model_executor.layers.fused_moe.fused_moe import ( + moe_align_block_size) + + out_hidden_states = torch.empty_like(x) + # unless we decent expert reuse we are better off running moe_vec kernel + if (qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES + and x.shape[0] > 64): + num_tokens, _ = x.shape + E, N, _ = w1.shape + top_k = topk_ids.shape[1] + BLOCK_SIZE = ops.ggml_moe_get_block_size(qweight_type) + + sorted_token_ids, expert_ids, num_tokens_post_padded = \ + moe_align_block_size(topk_ids, BLOCK_SIZE, E) + out = ops.ggml_moe_a8(x, w1, sorted_token_ids, expert_ids, + num_tokens_post_padded, qweight_type, N, top_k, + num_tokens) + out = act(out) + out = ops.ggml_moe_a8(out, w2, sorted_token_ids, expert_ids, + num_tokens_post_padded, qweight_type2, + w2.shape[1], 1, num_tokens * top_k) + out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_( + topk_weights.view(num_tokens, top_k, 1)) + ops.moe_sum(out, out_hidden_states) + elif qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES: + num_tokens, _ = x.shape + E, N, _ = w1.shape + top_k = topk_ids.shape[1] + + out = ops.ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N, + num_tokens) + out = act(out) + + out = ops.ggml_moe_a8_vec(out, w2, topk_ids, 1, qweight_type2, + w2.shape[1], num_tokens * top_k) + out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_( + topk_weights.view(num_tokens, top_k, 1)) + ops.moe_sum(out, out_hidden_states) + else: + logger.warning_once("There is no support for fast MoE kernel " + "for current quantization method. " + "Falling back to slow implementation. ") + for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)): + inp = x[tok].reshape((1, ) + x.shape[1:]) + current_hidden_state = None + for ww, ii in zip(w, idx): + expert_up = w1[ii] + + out = fused_mul_mat_gguf(inp, expert_up, qweight_type) + out = act(out) + + expert_down = w2[ii] + current_state = fused_mul_mat_gguf(out, expert_down, + qweight_type2).mul_(ww) + if current_hidden_state is None: + current_hidden_state = current_state + else: + current_hidden_state.add_(current_state) + out_hidden_states[tok] = current_hidden_state + return out_hidden_states + + +def _fused_moe_gguf_fake( + x: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + qweight_type: int, + qweight_type2: int, + activation: str, +) -> torch.Tensor: + return torch.empty_like(x) + + +try: + direct_register_custom_op( + op_name="_fused_moe_gguf", + op_func=_fused_moe_gguf, + mutates_args=[], + fake_impl=_fused_moe_gguf_fake, + ) + fused_moe_gguf = torch.ops.vllm._fused_moe_gguf + +except AttributeError as error: + raise error + + +def _apply_gguf_embedding( + x: torch.Tensor, + qweight: torch.Tensor, + qweight_type: int, + hidden_size: int, + dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + if qweight_type in UNQUANTIZED_TYPES: + return torch.embedding(qweight, x) + elif qweight_type in DEQUANT_TYPES: + block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] + x_flat = x.flatten() + assert (hidden_size == qweight.shape[1] // type_size * block_size) + quant = torch.index_select(qweight, dim=0, index=x_flat) + dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size, + x_flat.shape[0], dtype) + return dequant.view(*x.shape, hidden_size) + else: + qweight_type = WeightType(qweight_type) + raise NotImplementedError( + f"Unsupported GGUF quantization type: {qweight_type}") + + +def _apply_gguf_embedding_fake( + x: torch.Tensor, + qweight: torch.Tensor, + qweight_type: int, + hidden_size: int, + dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device) + + +try: + direct_register_custom_op( + op_name="_apply_gguf_embedding", + op_func=_apply_gguf_embedding, + mutates_args=[], + fake_impl=_apply_gguf_embedding_fake, + ) + apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding + +except AttributeError as error: + raise error + + +class GGUFLinearMethod(LinearMethodBase): + """Linear method for GGUF. + + Args: + quant_config: The GGUF quantization config. + """ + + def __init__(self, quant_config: GGUFConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + self.params_dtype = params_dtype + output_size_per_partition = sum(output_partition_sizes) + + tensor_shape = (output_size_per_partition, input_size_per_partition) + qweight = GGUFUninitializedParameter(requires_grad=False) + set_weight_attrs( + qweight, { + "input_dim": 1, + "output_dim": 0, + "tensor_shape": tensor_shape, + "is_gguf_weight": True, + "data_container": [], + "shard_id": [], + "shard_id_map": {}, + }) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("qweight", qweight) + + qweight_type = Parameter(torch.empty(len(output_partition_sizes), + dtype=torch.uint8), + requires_grad=False) + set_weight_attrs( + qweight_type, { + "is_gguf_weight_type": True, + "weight_type": 0, + "shard_weight_type": {}, + "ignore_warning": True + }) + set_weight_attrs(qweight_type, extra_weight_attrs) + layer.register_parameter("qweight_type", qweight_type) + + def process_weights_after_loading(self, layer: torch.nn.Module): + qweight_type = layer.qweight_type.weight_type + if not (qweight_type in UNQUANTIZED_TYPES + or qweight_type in DEQUANT_TYPES): + qweight_type = WeightType(qweight_type) + raise ValueError( + f"Unsupported GGUF quantization type {qweight_type} in " + f"layer {layer}.") + # For MergedColumnParallelLinear and QKVParallelLinear, we need to + # materialize the padded weight parameter for CUDA Graph compatibility. + self._create_padded_weight_param(layer) + + def _create_padded_weight_param(self, layer: torch.nn.Module): + """Create padded weight parameter for GGUF MergedLinear layer.""" + qweight = layer.qweight + shard_id_map = qweight.shard_id_map + shard_id = qweight.shard_id + if len(data_container := qweight.data_container) > 1: + dtype = {data.dtype for data in data_container} + assert len(dtype) == 1, ValueError( + f"Data container has mixed dtypes: {dtype}") + dtype = next(iter(dtype)) + # concat dim0 and pad dim1 + padded_side = max(x.size(1) for x in data_container) + concat_side = sum(x.size(0) for x in data_container) + # Pad the quantized weights to dense tensor, and create a map + # with the location of each shard in the padded tensor. + padded_data = torch.zeros((concat_side, padded_side), + dtype=dtype, + device=qweight.device) + # (dim0_start, dim0_end, dim1_size) + shard_offset_map = dict[str, tuple[int, int, int]]() + for idx in shard_id: + id_in_container = shard_id_map[idx] + start = sum( + x.size(0) for x in data_container[:id_in_container]) + end = start + data_container[id_in_container].size(0) + size = data_container[id_in_container].size(1) + padded_data[start:end, :size] = data_container[id_in_container] + shard_offset_map[idx] = (start, end, size) + qweight.data_container.clear() + padded_param = Parameter(padded_data, requires_grad=False) + set_weight_attrs(padded_param, vars(qweight)) + set_weight_attrs(padded_param, + {"shard_offset_map": shard_offset_map}) + layer.register_parameter("qweight", padded_param) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + shard_id = layer.qweight.shard_id + + if shard_id: + # dequantize shard weights respectively + shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id + qweight = layer.qweight + result = [] + for idx in shard_id: + start, end, offset = layer.qweight.shard_offset_map[idx] + qweight_type = layer.qweight_type.shard_weight_type[idx] + result.append( + fused_mul_mat_gguf( + x, qweight[start:end, :offset].contiguous(), + qweight_type)) + out = torch.cat(result, axis=1) + else: + qweight = layer.qweight + qweight_type = layer.qweight_type.weight_type + out = fused_mul_mat_gguf(x, qweight, qweight_type) + if bias is not None: + out.add_(bias) + return out + + +class GGUFMoEMethod(FusedMoEMethodBase): + """MoE method for GGUF. + + Args: + quant_config: The GGUF quantization config. + """ + + def __init__(self, quant_config: GGUFConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + tensor_shape = (num_experts, 2 * intermediate_size_per_partition, + hidden_size) + #gate up proj + w13_qweight = GGUFUninitializedParameter(requires_grad=False) + set_weight_attrs( + w13_qweight, { + "input_dim": 1, + "output_dim": 0, + "tensor_shape": tensor_shape, + "is_gguf_weight": True, + "data_container": [], + }) + set_weight_attrs(w13_qweight, extra_weight_attrs) + layer.register_parameter("w13_qweight", w13_qweight) + + w13_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8), + requires_grad=False) + set_weight_attrs(w13_qweight_type, { + "is_gguf_weight_type": True, + "weight_type": 0, + "ignore_warning": True + }) + set_weight_attrs(w13_qweight_type, extra_weight_attrs) + layer.register_parameter("w13_qweight_type", w13_qweight_type) + + tensor_shape = (num_experts, intermediate_size_per_partition, + hidden_size) + #gate down proj + w2_qweight = GGUFUninitializedParameter(requires_grad=False) + set_weight_attrs( + w2_qweight, { + "input_dim": 1, + "output_dim": 0, + "tensor_shape": tensor_shape, + "is_gguf_weight": True, + "data_container": [], + }) + set_weight_attrs(w2_qweight, extra_weight_attrs) + layer.register_parameter("w2_qweight", w2_qweight) + + w2_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8), + requires_grad=False) + set_weight_attrs(w2_qweight_type, { + "is_gguf_weight_type": True, + "weight_type": 0, + "ignore_warning": True + }) + + set_weight_attrs(w2_qweight_type, extra_weight_attrs) + layer.register_parameter("w2_qweight_type", w2_qweight_type) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ): + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `GGUFMoEMethod` yet.") + + assert activation == "silu", "Only SiLU activation is supported." + if apply_router_weight_on_input: + raise NotImplementedError( + "Apply router weight on input is not supported for" + "fused GGUF MoE method.") + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight, + topk_weights, topk_ids, + layer.w13_qweight_type.weight_type, + layer.w2_qweight_type.weight_type, activation) + + +class GGUFEmbeddingMethod(GGUFLinearMethod): + """Embedding method for GGUF. + + Args: + quant_config: The GGUF quantization config. + """ + + def embedding(self, layer: torch.nn.Module, + x: torch.Tensor) -> torch.Tensor: + qweight = layer.qweight + qweight_type = layer.qweight_type.weight_type + hidden_size = qweight.tensor_shape[1] + + return apply_gguf_embedding(x, + qweight, + qweight_type, + hidden_size, + dtype=self.params_dtype) + + +class GGUFUninitializedParameter(UninitializedParameter): + cls_to_become = Parameter + data_container: list[torch.Tensor] diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py new file mode 100644 index 0000000..d3ab1be --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -0,0 +1,278 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import enum +from enum import Enum +from fractions import Fraction +from typing import Any, Optional, Union + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.utils.gptq_utils import ( + get_linear_quant_method) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter) + + +class GPTQConfig(QuantizationConfig): + """Config class for GPTQ. + + Reference: https://arxiv.org/abs/2210.17323 + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + lm_head_quantized: bool, + dynamic: dict[str, dict[str, Union[int, bool]]], + ) -> None: + # GPTQModel use `dynamic` config property to allow per module + # quantization config so each module can be individually optimized. + # Format is dict[str, dict] where key is a regex string that can + # perform both positive ("+:" prefixed) or negative ("-:" prefixed) + # matching of a module. + # Default to positive match, override base quant config mode, if no + # prefix is used. Value is in dict format of field key and override + # value. + # Negative matching will skip quantization init for this module + # entirely: + # non-quantized inference. More details and quantization examples can be + # found at: https://github.com/ModelCloud/GPTQModel + # Example: + # # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9 + # # last 1/4 of the layers 16-21 has 8bit and group_size 64 + # dynamic = { + # #`.*\.` matches the layers_node prefix + # # positive match layer 10-15 + # r"+:.*\.(?:1[0-5])\..*": {"bits": 8,}, + # # positive match layer 16-21 + # r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,}, + # r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers + # } + super().__init__() + self.dynamic = dynamic + + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.lm_head_quantized = lm_head_quantized + self.pack_factor = Fraction(32, self.weight_bits) + if self.weight_bits not in [2, 3, 4, 8]: + raise ValueError( + "Currently, only 2/3/4/8-bit weight quantization is " + f"supported for GPTQ, but got {self.weight_bits} bits.") + + def __repr__(self) -> str: + return (f"GPTQConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act}), " + f"lm_head_quantized={self.lm_head_quantized}), " + f"dynamic={self.dynamic}") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "gptq" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "GPTQConfig": + dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) + dynamic = {} if dynamic is None else dynamic + + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(weight_bits, group_size, desc_act, lm_head_quantized, + dynamic) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["GPTQLinearMethod"]: + return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod) + + +class ExllamaState(Enum): + + UNUSED = enum.auto() + UNINITIALIZED = enum.auto() + READY = enum.auto() + + +class GPTQLinearMethod(LinearMethodBase): + """Linear method for GPTQ. + + Args: + quant_config: The GPTQ quantization config. + """ + + def __init__(self, quant_config: GPTQConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused. + weight_loader = extra_weight_attrs.get("weight_loader") + if input_size_per_partition % self.quant_config.group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + output_size_per_partition = sum(output_partition_sizes) + if (output_size_per_partition % self.quant_config.pack_factor.numerator + != 0): + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + exllama_state = ExllamaState.UNINITIALIZED + scale_and_zero_size = input_size // group_size + scale_and_zero_input_dim = None + if (input_size != input_size_per_partition + and self.quant_config.group_size != -1): + # For act-order models, we cannot use Exllama for row parallel layer + if self.quant_config.desc_act: + exllama_state = ExllamaState.UNUSED + else: + # we need to partition qzeros and scales for exllama kernel + scale_and_zero_size = input_size_per_partition // group_size + scale_and_zero_input_dim = 0 + + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + g_idx = RowvLLMParameter(data=torch.tensor( + [ + i // self.quant_config.group_size + for i in range(input_size_per_partition) + ], + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + qzeros_args = { + "data": + torch.empty( + scale_and_zero_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + "weight_loader": + weight_loader + } + weight_scale_args = { + "data": + torch.empty( + scale_and_zero_size, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + if scale_and_zero_input_dim is None: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + qzeros = PackedColumnParameter( + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + qzeros = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("g_idx", g_idx) + layer.register_parameter("qzeros", qzeros) + layer.register_parameter("scales", scales) + + layer.exllama_state = exllama_state + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # for torch.compile + layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False) + layer.qweight = Parameter(layer.qweight.data, requires_grad=False) + layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False) + layer.scales = Parameter(layer.scales.data, requires_grad=False) + + # exllama needs to shuffle the weight after the weight is loaded + # here we do the shuffle on first forward pass + if layer.exllama_state == ExllamaState.UNINITIALIZED: + if self.quant_config.desc_act: + layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) + else: + layer.g_idx.data = torch.empty((0, ), + dtype=torch.int, + device=layer.g_idx.device) + layer.exllama_state = ExllamaState.READY + ops.gptq_shuffle(layer.qweight, layer.g_idx, + self.quant_config.weight_bits) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + out_shape = x.shape[:-1] + (layer.qweight.shape[-1], ) + reshaped_x = x.reshape(-1, x.shape[-1]) + + output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros, + layer.scales, layer.g_idx, + layer.exllama_state == ExllamaState.READY, + self.quant_config.weight_bits) + if bias is not None: + output.add_(bias) + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py new file mode 100644 index 0000000..caeb266 --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -0,0 +1,446 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( + BitBLASLinearKernel, MPLinearLayerConfig) +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + BITBLAS_SUPPORTED_NUM_BITS as GPTQ_BITBLAS_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + BITBLAS_SUPPORTED_SYM as GPTQ_BITBLAS_SUPPORTED_SYM) +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + MINIMUM_BITBLAS_VERSION, bitblas_repeat_scales_on_all_ranks, + check_bitblas_supported, verify_bitblas_supported) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + + +class GPTQBitBLASConfig(QuantizationConfig): + """Config class for GPTQ BitBLAS""" + + # (num_bits, is_sym) -> quant_type + TYPE_MAP = { + (4, True): scalar_types.uint4b8, + (8, True): scalar_types.uint8b128, + } + + TORCH_DTYPE = torch.float16 + GPTQ_CKPT_STORAGE_DTYPE = ( + "int32" # GPTQ Default Checkpoints use int32 as storage dtype + ) + GPTQ_BITBLAS_STORAGE_DTYPE = "int8" # BitBLAS uses int8 as storage dtype + TORCH_BITBLAS_STORAGE_DTYPE = getattr(torch, GPTQ_BITBLAS_STORAGE_DTYPE) + # "original" or "rescale" or "quantized", + # the gptq_bitblas prefer "quantized" + ZEROS_MODE = "quantized" + + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + is_sym: bool, + quant_method: Optional[str], + lm_head_quantized: bool, + ) -> None: + + try: + import bitblas + if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: + raise ImportError( + "bitblas version is wrong. Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + except ImportError as e: + bitblas_import_exception = e + raise ValueError( + "Trying to use the bitblas backend, but could not import" + f"with the following error: {bitblas_import_exception}. " + "Please install bitblas through the following command: " + f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`" + ) from bitblas_import_exception + + if desc_act and group_size == -1: + # In this case, act_order == True is the same as act_order == False + # (since we have only one group per output channel) + desc_act = False + + super().__init__() + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.is_sym = is_sym + self.quant_method = quant_method + self.lm_head_quantized = lm_head_quantized + + # Verify + if self.weight_bits not in GPTQ_BITBLAS_SUPPORTED_NUM_BITS: + raise ValueError( + f"BitBLAS does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {GPTQ_BITBLAS_SUPPORTED_NUM_BITS} " + "are supported.") + + if self.is_sym not in GPTQ_BITBLAS_SUPPORTED_SYM: + raise ValueError( + f"BitBLAS does not support is_sym = {self.is_sym}. " + f"Only sym = {GPTQ_BITBLAS_SUPPORTED_SYM} are supported.") + + self.storage_dtype = self.GPTQ_BITBLAS_STORAGE_DTYPE + + storage_nbit = int("".join(c for c in self.GPTQ_CKPT_STORAGE_DTYPE + if c.isdigit())) + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = storage_nbit // weight_bits + self.nbits = weight_bits + + # Zeros type for the quantized weights. + self.zeros_mode = self.ZEROS_MODE + + if (weight_bits, is_sym) not in self.TYPE_MAP: + raise ValueError("Unsupported quantization config: " + f"bits={weight_bits}, sym={is_sym}") + + self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] + + def __repr__(self) -> str: + return (f"GPTQBitBLASConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act})" + f"is_sym={self.is_sym}, " + f"quant_method={self.quant_method})") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "gptq_bitblas" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "GPTQBitBLASConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + is_sym = cls.get_from_keys(config, ["sym"]) + quant_method = cls.get_from_keys(config, ["quant_method"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(weight_bits, group_size, desc_act, is_sym, quant_method, + lm_head_quantized) + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg) + + is_valid_user_quant = (user_quant is None or user_quant == "bitblas" + or user_quant == "gptq_bitblas") + + if can_convert and is_valid_user_quant: + msg = ("The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "gptq": + logger.info("Detected that the model can run with gptq_bitblas" + ", however you specified quantization=gptq explicitly," + " so forcing gptq. Use quantization=gptq_bitblas for" + " faster inference") + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["GPTQBitBLASLinearMethod"]: + if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) + and self.lm_head_quantized): + return GPTQBitBLASLinearMethod(self) + return None + + @property + def torch_storage_dtype(self) -> torch.dtype: + return self.TORCH_BITBLAS_STORAGE_DTYPE + + @classmethod + def is_gptq_bitblas_compatible(cls, quant_config: dict[str, Any]): + # Extract data from quant config. + num_bits = quant_config.get("bits") + group_size = quant_config.get("group_size") + sym = quant_config.get("sym") + desc_act = quant_config.get("desc_act") + + # temporarily disable on ROCm platform + if not current_platform.is_cuda(): + return False + + # If we cannot find the info needed in the config, cannot convert. + if (num_bits is None or group_size is None or sym is None + or desc_act is None): + return False + + if (num_bits, sym) not in cls.TYPE_MAP: + return False + + # If the capability of the device is too low, cannot convert. + major, minor = torch.cuda.get_device_capability() + device_capability = major * 10 + minor + if device_capability < cls.get_min_capability(): + return False + + # Otherwise, can convert if model satisfies bitblas constraints. + return check_bitblas_supported(quant_type=cls.TYPE_MAP[(num_bits, + sym)], + group_size=group_size) + + +class GPTQBitBLASLinearMethod(LinearMethodBase): + """Linear method for GPTQ BitBLAS. + + Args: + quant_config: The GPTQ BitBLAS quantization config. + """ + + kernel_type = BitBLASLinearKernel + _kernel_backends_being_used: set[str] = set() + + def __init__(self, quant_config: GPTQBitBLASConfig) -> None: + self.quant_config = quant_config + # Verify supported on platform. + verify_bitblas_supported(quant_type=self.quant_config.quant_type, + group_size=self.quant_config.group_size) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + """Creates quantized weights for use in linear operations. + + The function initializes and returns a dictionary containing + quantized weights, scales, and zeros + for performing quantized matrix multiplication operations. + + Args: + input_size_per_partition: The size of the input partition. + output_partition_sizes: The size of the output partition. + input_size: The total size of the input (unused). + output_size: The total size of the output (unused). + params_dtype: + The data type of the parameters (expected to be torch.float16). + + Returns: + A dictionary containing the quantized weights ('qweight'), + scales ('scales'), and zeros ('zeros'). + + Raises: + ValueError: If `params_dtype` is not `torch.float16` or + if the input size per partition is not divisible by the + group size in `quant_config`. + """ + if params_dtype != torch.float16: + raise ValueError("Parameter data type must be torch.float16, " + f"but got {params_dtype}") + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + if input_size_per_partition % group_size != 0: + raise ValueError( + f"Input size per partition ({input_size_per_partition}) must " + f"be divisible by group size ({self.quant_config.group_size})." + ) + + kernel_type = self.kernel_type + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + + is_row_parallel = input_size != input_size_per_partition + weight_loader = extra_weight_attrs.get("weight_loader") + + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_config.quant_type, + act_type=params_dtype, + group_size=self.quant_config.group_size, + zero_points=False, + has_g_idx=self.quant_config.desc_act + ) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for GPTQBitBLASLinearMethod", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + # Determine sharding + if bitblas_repeat_scales_on_all_ranks(self.quant_config.desc_act, + self.quant_config.group_size, + is_row_parallel): + # By setting scale_dim == None, weight_loader will + # repeat the scales on each GPU in TP>1 case. + scales_and_zp_input_dim = None + scales_and_zp_size = input_size // group_size + else: + # By setting scale_dim == 0, weight_loader will + # shard the scales in TP>1 case. + scales_and_zp_input_dim = 0 + scales_and_zp_size = input_size_per_partition // group_size + + # Init buffers + # Quantized weights + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + # Activation order + # Ignore warning from fused linear layers such as QKVParallelLinear. + g_idx = RowvLLMParameter(data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + + # Scales + scales = Parameter( + torch.empty( + scales_and_zp_size, + output_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, + { + **extra_weight_attrs, + "input_dim": scales_and_zp_input_dim, + "output_dim": 1, + }, + ) + + # Quantized zero-points + qzeros_args = { + "data": + torch.empty( + scales_and_zp_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + "weight_loader": + weight_loader + } + weight_scale_args = { + "data": + torch.empty( + scales_and_zp_size, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + + if scales_and_zp_input_dim is None: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + qzeros = PackedColumnParameter( + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + qzeros = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("g_idx", g_idx) + layer.register_parameter("scales", scales) + layer.register_parameter("qzeros", qzeros) + + self.kernel = kernel_type( + mp_linear_kernel_config, + w_q_param_name="qweight", + w_s_param_name="scales", + w_zp_param_name="qzeros", + w_gidx_param_name="g_idx", + bitblas_quant_config=self.quant_config, + ) + + # Initialize or retrieve the BitBLAS matrix multiplication operator. + self.kernel.configure_bitblas_matmul( + input_size_per_partition, + output_size_per_partition, + params_dtype=params_dtype, + bias=False, + ) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + out = self.kernel.apply_gptq_bitblas_linear(layer, x) + if bias is not None: + out.add_(bias) + return out diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py new file mode 100644 index 0000000..9bed5e2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -0,0 +1,679 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from copy import deepcopy +from typing import Any, Callable, Optional, Union + +import torch + +import vllm.model_executor.layers.fused_moe # noqa +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, + UnquantizedFusedMoEMethod) +from vllm.model_executor.layers.linear import (LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( + MPLinearLayerConfig, choose_mp_linear_kernel) +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.gptq_utils import ( + get_dynamic_override, get_linear_quant_method, override_config) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supported, check_moe_marlin_supports_layer, + marlin_make_workspace_new, marlin_moe_permute_scales, + marlin_repeat_scales_on_all_ranks, verify_marlin_supported) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + + +def get_moe_quant_method( + config: QuantizationConfig, + layer: torch.nn.Module, + prefix: str, + moe_method_cls: type, +): + cloned_config = deepcopy(config) + + if isinstance(layer, FusedMoE): + # False = skip module, None = no override, else = Positive match + if get_dynamic_override( # noqa: E712 + cloned_config, # noqa: E712 + layer_name=prefix) == False: # noqa: E712 + return UnquantizedFusedMoEMethod(layer.moe_config) + + if prefix: + # Dynamic per module/layer rules may override base config + override_config(cloned_config, prefix=prefix) + + return moe_method_cls(cloned_config) + return None + + +class GPTQMarlinConfig(QuantizationConfig): + """Config class for GPTQ Marlin""" + + # (num_bits, is_sym) -> quant_type + TYPE_MAP = { + (4, True): scalar_types.uint4b8, + (8, True): scalar_types.uint8b128, + } + + def __init__(self, weight_bits: int, group_size: int, desc_act: bool, + is_sym: bool, lm_head_quantized: bool, + dynamic: dict[str, dict[str, Union[int, bool]]], + full_config: dict[str, Any]) -> None: + super().__init__() + if desc_act and group_size == -1: + # In this case, act_order == True is the same as act_order == False + # (since we have only one group per output channel) + desc_act = False + + # GPTQModel use `dynamic` config property to allow per module + # quantization config so each module can be individually optimized. + # Format is dict[str, dict] where key is a regex string that can + # perform both positive ("+:" prefixed) or negative ("-:" prefixed) + # matching of a module. + # Default to positive match, override base quant config mode, if no + # prefix is used. Value is in dict format of field key and override + # value. + # Negative matching will skip quantization init for this module + # entirely: + # non-quantized inference. More details and quantization examples can be + # found at: https://github.com/ModelCloud/GPTQModel + # Example: + # # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9 + # # last 1/4 of the layers 16-21 has 8bit and group_size 64 + # dynamic = { + # #`.*\.` matches the layers_node prefix + # # positive match layer 10-15 + # r"+:.*\.(?:1[0-5])\..*": {"bits": 8,}, + # # positive match layer 16-21 + # r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,}, + # r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers + # } + self.dynamic = dynamic + + self.weight_bits = weight_bits + self.is_sym = is_sym + + self.pack_factor = 32 // weight_bits # packed into int32 + self.group_size = group_size + self.desc_act = desc_act + self.lm_head_quantized = lm_head_quantized + self.full_config = full_config + + if (weight_bits, is_sym) not in self.TYPE_MAP: + raise ValueError("Unsupported quantization config: " + f"bits={weight_bits}, sym={is_sym}") + + self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] + + def __repr__(self) -> str: + return (f"GPTQMarlinConfig(quant_type={self.quant_type}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act}, " + f"lm_head_quantized={self.lm_head_quantized}), " + f"dynamic={self.dynamic}") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "gptq_marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "GPTQMarlinConfig": + dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) + dynamic = {} if dynamic is None else dynamic + + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + is_sym = cls.get_from_keys(config, ["sym"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(weight_bits, group_size, desc_act, is_sym, + lm_head_quantized, dynamic, config) + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg) + + is_valid_user_quant = (user_quant is None or user_quant == "marlin" + or user_quant == "gptq_marlin") + + if can_convert and is_valid_user_quant: + msg = ("The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "gptq": + logger.info("Detected that the model can run with gptq_marlin" + ", however you specified quantization=gptq explicitly," + " so forcing gptq. Use quantization=gptq_marlin for" + " faster inference") + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, FusedMoE): + from vllm.model_executor.layers.quantization.moe_wna16 import ( + MoeWNA16Config) + if not check_moe_marlin_supports_layer(layer, self.group_size): + logger.warning_once( + f"Layer '{prefix}' is not supported by GPTQMoeMarlin. " + "Falling back to Moe WNA16 kernels.") + return MoeWNA16Config.from_config( + self.full_config).get_quant_method(layer, prefix) + return get_moe_quant_method(self, layer, prefix, + GPTQMarlinMoEMethod) + return get_linear_quant_method(self, layer, prefix, + GPTQMarlinLinearMethod) + + @classmethod + def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]): + quant_method = quant_config.get("quant_method", "").lower() + num_bits = quant_config.get("bits") + group_size = quant_config.get("group_size") + sym = quant_config.get("sym") + desc_act = quant_config.get("desc_act") + + if not current_platform.is_cuda(): + return False + + if quant_method != "gptq": + return False + + # Marlin conversion is only valid if required properties are found + if (num_bits is None or group_size is None or sym is None + or desc_act is None): + return False + + if (num_bits, sym) not in cls.TYPE_MAP: + return False + + return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)], + group_size=group_size) + + +class GPTQMarlinLinearMethod(LinearMethodBase): + """Linear method for GPTQ Marlin. + + Args: + quant_config: The GPTQ Marlin quantization config. + """ + + _kernel_backends_being_used: set[str] = set() + + def __init__(self, quant_config: GPTQMarlinConfig) -> None: + self.quant_config = quant_config + + # Verify supported on platform. + verify_marlin_supported(quant_type=self.quant_config.quant_type, + group_size=self.quant_config.group_size) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + output_size_per_partition = sum(output_partition_sizes) + is_row_parallel = input_size != input_size_per_partition + weight_loader = extra_weight_attrs.get("weight_loader") + + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_config.quant_type, + act_type=params_dtype, + group_size=self.quant_config.group_size, + zero_points=False, + has_g_idx=self.quant_config.desc_act + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for GPTQMarlinLinearMethod", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + # Determine sharding + if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, + self.quant_config.group_size, + is_row_parallel): + # By setting scale_dim == None, weight_loader will + # repeat the scales on each GPU in TP>1 case. + scales_and_zp_input_dim = None + scales_and_zp_size = input_size // group_size + else: + # By setting scale_dim == 0, weight_loader will + # shard the scales in TP>1 case. + scales_and_zp_input_dim = 0 + scales_and_zp_size = input_size_per_partition // group_size + + # Quantized weights + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + # Activation order + g_idx = RowvLLMParameter(data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + + qzeros_args = { + "data": + torch.empty( + scales_and_zp_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + "weight_loader": + weight_loader + } + weight_scale_args = { + "data": + torch.empty( + scales_and_zp_size, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + + if scales_and_zp_input_dim is None: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + qzeros = PackedColumnParameter( + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + qzeros = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("g_idx", g_idx) + layer.register_parameter("scales", scales) + layer.register_parameter("qzeros", qzeros) + + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="qweight", + w_s_param_name="scales", + w_zp_param_name="qzeros", + w_gidx_param_name="g_idx") + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.kernel.apply_weights(layer, x, bias) + + +class GPTQMarlinMoEMethod(FusedMoEMethodBase): + """MoE Marlin method with quantization.""" + + def __init__(self, quant_config: GPTQMarlinConfig) -> None: + self.quant_config = quant_config + if self.quant_config.quant_type.size_bits == 4: + self.quant_type = scalar_types.uint4b8 + elif self.quant_config.quant_type.size_bits == 8: + self.quant_type = scalar_types.uint8b128 + else: + raise ValueError( + "GPTQMarlinMoEMethod only supports int4 and int8 now.") + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + intermediate_size_full = extra_weight_attrs.pop( + "intermediate_size_full") + + self.is_k_full = (not self.quant_config.desc_act) or ( + intermediate_size_per_partition == intermediate_size_full) + + if self.quant_config.group_size != -1: + scales_size13 = hidden_size // self.quant_config.group_size + w2_scales_size = (intermediate_size_full + if self.quant_config.desc_act else + intermediate_size_per_partition) + scales_size2 = (w2_scales_size // self.quant_config.group_size) + strategy = FusedMoeWeightScaleSupported.GROUP.value + else: + scales_size13 = 1 + scales_size2 = 1 + strategy = FusedMoeWeightScaleSupported.CHANNEL.value + + extra_weight_attrs.update({ + "quant_method": strategy, + "is_transposed": True + }) + # Fused gate_up_proj (column parallel) + w13_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size // self.quant_config.pack_factor, + 2 * intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + # down_proj (row parallel) + w2_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition // + self.quant_config.pack_factor, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + # up_proj scales + w13_scales = torch.nn.Parameter( + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + # down_proj scales + w2_scales = torch.nn.Parameter( + torch.empty(num_experts, + scales_size2, + hidden_size, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + # dont shard the w2 scales when running act order + set_weight_attrs(w2_scales, + {"load_full_w2": self.quant_config.desc_act}) + # up_proj scales + w13_qzeros = torch.nn.Parameter( + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size_per_partition // + self.quant_config.pack_factor, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + # down_proj scales + w2_qzeros = torch.nn.Parameter( + torch.empty(num_experts, + scales_size2, + hidden_size // self.quant_config.pack_factor, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + # dont shard the w2 scales when running act order + set_weight_attrs(w2_qzeros, + {"load_full_w2": self.quant_config.desc_act}) + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + + device = layer.w13_qweight.device + layer.workspace = marlin_make_workspace_new(device, 4) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + # Process act_order + if self.quant_config.desc_act: + # Get sorting based on g_idx + num_experts = layer.w13_g_idx.shape[0] + w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx) + w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx) + w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) + w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) + for e in range(num_experts): + w13_g_idx_sort_indices[e] = torch.argsort( + layer.w13_g_idx[e]).to(torch.int32) + w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to( + torch.int32) + w13_sorted_g_idx[e] = layer.w13_g_idx[e][ + w13_g_idx_sort_indices[e]] + w2_sorted_g_idx[e] = layer.w2_g_idx[e][ + w2_g_idx_sort_indices[e]] + replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx) + replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx) + replace_parameter(layer, "w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + replace_parameter(layer, "w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + else: + # Reset g_idx related tensors + num_experts = layer.w13_g_idx.shape[0] + device = layer.w13_g_idx.device + layer.w13_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + # Repack weights + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + layer.w13_qweight.shape[1] * self.quant_config.pack_factor, + layer.w13_qweight.shape[2], + self.quant_config.quant_type.size_bits, + ) + replace_parameter(layer, "w13_qweight", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + layer.w2_qweight.shape[1] * self.quant_config.pack_factor, + layer.w2_qweight.shape[2], + self.quant_config.quant_type.size_bits, + ) + replace_parameter(layer, "w2_qweight", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w13_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_parameter(layer, "w13_scales", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.w2_scales.shape[1] * + (self.quant_config.group_size if self.quant_config.group_size != -1 + else self.quant_config.pack_factor), + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_parameter(layer, "w2_scales", marlin_w2_scales) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `GPTQMarlinMoEMethod` yet.") + + assert activation == "silu", "Only SiLU activation is supported." + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_qweight, + layer.w2_qweight, + layer.w13_scales, + layer.w2_scales, + router_logits, + topk_weights, + topk_ids, + quant_type_id=self.quant_type.id, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + g_idx1=layer.w13_g_idx, + g_idx2=layer.w2_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, + workspace=layer.workspace, + is_k_full=self.is_k_full) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py new file mode 100644 index 0000000..eba917d --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -0,0 +1,297 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + +GPTQ_MARLIN_24_TILE = 16 +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 + +GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [ + scalar_types.uint4b8, scalar_types.uint8b128 +] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] + + +class GPTQMarlin24Config(QuantizationConfig): + """Config class for Marlin24. + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + ) -> None: + super().__init__() + quant_type = { + 4: scalar_types.uint4b8, + 8: scalar_types.uint8b128, + }.get(weight_bits) + + self.group_size = group_size + + # Verify + if quant_type is None or \ + quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES: + raise ValueError( + f"Marlin_24 does not support quant_type = {quant_type}. " + f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} " + "are supported.") + if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: + raise ValueError( + f"Marlin_24 does not support group_size = {self.group_size}. " + f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} " + "are supported.") + + self.quant_type = quant_type + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = 32 // self.quant_type.size_bits + + # Tile size used by marlin kernels. + self.tile_size = 16 + + # Min out_features dim + self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N + + # Min in_features dim + self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K + + # Max parallel problems to solve at once (improves large + # batch performance) + self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL + + # Permutation length used by the marlin kernels. + self.perm_len = 1024 + + def __repr__(self) -> str: + return "Marlin24Config(quant_type={}, group_size={})".format( + self.quant_type, self.group_size) + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "gptq_marlin_24" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "GPTQMarlin24Config": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits, group_size) + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + is_marlin_24_format = ( + hf_quant_cfg.get("checkpoint_format") == "marlin_24") + + is_valid_user_quant = (user_quant is None or user_quant == "gptq" + or user_quant == "gptq_marlin_24") + + if is_marlin_24_format and is_valid_user_quant: + msg = ("The model is serialized in {} format. " + "Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["GPTQMarlin24LinearMethod"]: + if isinstance(layer, LinearBase): + return GPTQMarlin24LinearMethod(self) + return None + + +class GPTQMarlin24LinearMethod(LinearMethodBase): + """Linear method for Marlin24. + + Args: + quant_config: The Marlin24 quantization config. + """ + + def __init__(self, quant_config: GPTQMarlin24Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused. + weight_loader = extra_weight_attrs["weight_loader"] + if params_dtype != torch.float16: + raise ValueError( + f"The params dtype must be float16, but got {params_dtype}") + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.min_n_threads != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"min_n_threads = {self.quant_config.min_n_threads}.") + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"pack_factor = {self.quant_config.pack_factor}.") + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_k_threads != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"min_k_threads = {self.quant_config.min_k_threads}.") + if (self.quant_config.group_size != -1 and + input_size_per_partition % self.quant_config.group_size != 0): + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"group_size = {self.quant_config.group_size}.") + + # Check that we have at least 4 tiles horizontally in the shard + num_tiles_per_perm = self.quant_config.perm_len // ( + self.quant_config.tile_size**2) + if output_size_per_partition % num_tiles_per_perm != 0: + raise ValueError( + "Each permutation group must reside on the same gpu") + + # Quantized 4Bit weights packed into Int32. + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.tile_size // 2, + output_size_per_partition * self.quant_config.tile_size // + self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + marlin_tile_size=self.quant_config.tile_size, + weight_loader=weight_loader) + + # Meta + meta = PackedvLLMParameter(data=torch.empty( + input_size_per_partition // 8 // 2 // 2, + output_size_per_partition * 2, + device="cuda", + dtype=torch.int16, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=1, + marlin_tile_size=2, + weight_loader=weight_loader) + + # Determine if channelwise or not + input_groups = (1 if self.quant_config.group_size == -1 else + input_size_per_partition // + self.quant_config.group_size) + + weight_scale_args = { + "data": + torch.empty( + input_groups, + output_size_per_partition, + device="cuda", + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + if input_groups == 1: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + + # Allocate workspace (Used for internal locking mechanism) + max_workspace_size = ( + output_size_per_partition // + self.quant_config.min_n_threads) * self.quant_config.max_parallel + + workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, + device="cuda", + dtype=torch.int), + weight_loader=weight_loader) + + layer.register_parameter("B_24", qweight) + layer.register_parameter("B_meta", meta) + layer.register_parameter("s", scales) + layer.register_parameter("workspace", workspace) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # required by torch.compile + layer.B_24 = Parameter(layer.B_24.data, requires_grad=False) + layer.s = Parameter(layer.s.data, requires_grad=False) + layer.B_meta = Parameter(layer.B_meta.data, requires_grad=False) + layer.workspace = Parameter(layer.workspace.data, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.B_24 + meta = layer.B_meta + scales = layer.s + workspace = layer.workspace + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = scales.shape[1] + + output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, + workspace, + self.quant_config.quant_type, + size_m, size_n, size_k) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py new file mode 100644 index 0000000..ee8a0e3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -0,0 +1,332 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, + marlin_make_empty_g_idx, marlin_permute_scales) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + MarlinWorkspace) +from vllm.model_executor.layers.quantization.utils.quant_utils import gptq_pack +from vllm.model_executor.parameter import (BasevLLMParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + + +class HQQMarlinConfig(QuantizationConfig): + """Config class for HQQ Marlin""" + + def __init__( + self, + weight_bits: int, + group_size: int, + skip_modules: Optional[list[str]] = None, + ) -> None: + super().__init__() + assert group_size == 64, ("The only supported HQQ group size is " + "currently 64.") + assert weight_bits == 4, ("The only supported HQQ quantization " + "bitsize is currently 4.") + + self.weight_bits = weight_bits + self.group_size = group_size + self.pack_factor = 32 // weight_bits # packed into int32 in GPTQ format + self.quant_type = scalar_types.uint4 + self.skip_modules = skip_modules + + def __repr__(self) -> str: + return (f"HQQMarlinConfig(quant_type={self.quant_type}, " + f"group_size={self.group_size})") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "hqq" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "HQQMarlinConfig": + wq_params = (config["quant_config"]["weight_quant_params"]) + weight_bits = cls.get_from_keys(wq_params, ["nbits"]) + group_size = cls.get_from_keys(wq_params, ["group_size"]) + skip_modules = config["skip_modules"] + return cls(weight_bits, group_size, skip_modules) + + def is_layer_skipped(self, prefix: str) -> bool: + # Split the prefix into its dot-separated components + components = prefix.split('.') + + # Check if any of the skip modules exactly matches any component + return self.skip_modules is not None and any( + module_name in components for module_name in self.skip_modules) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + if self.is_layer_skipped(prefix): + return UnquantizedLinearMethod() + return HQQMarlinMethod(self) + return None + + +# Empty HQQ parameter, will be ignored during loading +class HQQEmptyParameter(BasevLLMParameter): + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + pass + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + pass + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + pass + + +def error_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + raise ValueError("No loader provided for HQQ parameter!") + + +# HQQ packing creates issues with sharding - therefore, prior to loading, we +# repack to GPTQ. We also reshape the weights to their proper GPTQ shape. +class HQQweightParameter(PackedvLLMParameter): + + # unpack function from https://github.com/mobiusml/hqq + def unpack_4bit_u8(self, + W_q: torch.Tensor) -> torch.Tensor: # uint8/2 > uint8 + assert self.weight_bits == 4, "Unsupported quant bitsize (must be 4)" + + dtype = torch.uint8 + step = W_q.shape[0] + tmp = torch.empty([2 * step, W_q.shape[1]], + dtype=dtype, + device=W_q.device) + tmp[:step] = (W_q & 0b11110000) >> 4 + tmp[step:] = W_q & 0b00001111 + return tmp + + def __init__(self, packed_factor: int, packed_dim: int, weight_bits: int, + **kwargs): + super().__init__(packed_factor, packed_dim, None, **kwargs) + self.weight_bits = weight_bits + self.input_shape = self.shape[self.input_dim] * self.packed_factor + self.output_shape = self.shape[self.output_dim] + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + loaded_weight = self.unpack_4bit_u8(loaded_weight) + loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose( + 1, 0) + loaded_weight = gptq_pack(loaded_weight, self.weight_bits, + loaded_weight.shape[0], + loaded_weight.shape[1]) + super().load_merged_column_weight(loaded_weight, **kwargs) + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + loaded_weight = self.unpack_4bit_u8(loaded_weight) + loaded_weight = loaded_weight.reshape(self.output_shape, + -1).transpose(1, 0) + loaded_weight = gptq_pack(loaded_weight, self.weight_bits, + loaded_weight.shape[0], + loaded_weight.shape[1]) + super().load_row_parallel_weight(loaded_weight) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + loaded_weight = self.unpack_4bit_u8(loaded_weight) + loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose( + 1, 0) + loaded_weight = gptq_pack(loaded_weight, self.weight_bits, + loaded_weight.shape[0], + loaded_weight.shape[1]) + super().load_qkv_weight(loaded_weight, **kwargs) + + +# Zero points and scales in HQQ must also be reshaped to correspond to W_q's +# GPTQ shape (transposed - we transpose them too when processing weights). +class HQQZeroScaleParameter(GroupQuantScaleParameter): + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + loaded_weight = loaded_weight.reshape(-1, self.shape[1]) + super().load_merged_column_weight(loaded_weight, **kwargs) + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + loaded_weight = loaded_weight.reshape(self.shape[0], -1) + super().load_row_parallel_weight(loaded_weight) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + loaded_weight = loaded_weight.reshape(-1, self.shape[1]) + super().load_qkv_weight(loaded_weight, **kwargs) + + +class HQQMarlinMethod(LinearMethodBase): + """Linear method for HQQ Marlin. + """ + + def __init__( + self, + quant_config: HQQMarlinConfig, + ): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + self.output_size_per_partition = sum(output_partition_sizes) + self.input_size_per_partition = input_size_per_partition + + weight_loader = extra_weight_attrs.get("weight_loader", error_loader) + + self.scales_and_zp_size = (input_size_per_partition // + self.quant_config.group_size) + + qweight = HQQweightParameter( + data=torch.empty( + self.input_size_per_partition // self.quant_config.pack_factor, + self.output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_bits=self.quant_config.weight_bits, + weight_loader=weight_loader) + + zeros = HQQZeroScaleParameter(data=torch.empty( + self.output_size_per_partition, + self.scales_and_zp_size, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + scales = HQQZeroScaleParameter(data=torch.empty( + self.output_size_per_partition, + self.scales_and_zp_size, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("W_q", qweight) + layer.register_parameter("zero", zeros) + layer.register_parameter("scale", scales) + + # Ignore extra parameters in the HQQ model. + # To be added as needed. + ignore_parameters = ("axis", "channel_wise", "compute_dtype", + "encoded_state_dict", "group_size", "nbits", + "offload_meta", "optimize", "packing", + "quant_scale", "quant_zero", "round_zero", + "shape", "stores_quant_config", + "unpack_view_dtype", "view_as_float") + for name in ignore_parameters: + layer.register_parameter( + name, + HQQEmptyParameter(data=torch.empty(0), + weight_loader=weight_loader)) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + dev = layer.W_q.device + + # Repack to Marlin + sort_indices = torch.empty(0, dtype=torch.int, device=dev) + marlin_w_q = ops.gptq_marlin_repack( + layer.W_q, + sort_indices, + self.input_size_per_partition, + self.output_size_per_partition, + self.quant_config.weight_bits, + ).to(dev) + marlin_s = marlin_permute_scales(layer.scale.transpose(1, 0), + self.input_size_per_partition, + self.output_size_per_partition, + self.quant_config.group_size).to(dev) + marlin_zp = marlin_permute_scales(layer.zero.transpose(1, 0), + self.input_size_per_partition, + self.output_size_per_partition, + self.quant_config.group_size).to(dev) + + layer.g_idx = marlin_make_empty_g_idx(dev) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(dev) + + layer.marlin_qweight = marlin_w_q + layer.marlin_zeros = marlin_zp + layer.marlin_scales = marlin_s + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + workspace = MarlinWorkspace(self.output_size_per_partition, + GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) + + scales = layer.marlin_scales + zeros = layer.marlin_zeros + orig_type = x.dtype + + if orig_type != torch.float16: + x = x.to(torch.float16) + scales = scales.to(torch.float16) + zeros = zeros.to(torch.float16) + + marlin_out = ops.gptq_marlin_gemm( + x, + None, + layer.marlin_qweight, + scales, + None, + zeros, + layer.g_idx, + layer.g_idx_sort_indices, + workspace.scratch, + scalar_types.uint4, + x.shape[0], + self.output_size_per_partition, + self.input_size_per_partition, + True, # is_k_full + False, # use atomic add + True, # use 32-bit reduce + True, # use float zp + ) + + if orig_type != torch.float16: + marlin_out = marlin_out.to(orig_type) + + if bias is not None: + marlin_out.add_(bias) + + return marlin_out diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py new file mode 100644 index 0000000..428e9b8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -0,0 +1,250 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Optional + +import torch + +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod, + is_layer_skipped_awq) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod +from vllm.platforms import current_platform + +MIN_IPEX_VERSION = "2.6.0" + + +class IPEXConfig(QuantizationConfig): + """INT8 quantization config class using IPEX for the CPU/XPU backend, + including AWQ, GPTQ. + """ + + IPEX_QUANT_METHOD_MAP = { + "awq": 1, + "gptq": 0, + } + + def __init__( + self, + method: str, + weight_bits: int, + group_size: int, + modules_to_not_convert: Optional[list[str]] = None, + desc_act: Optional[bool] = None, + lm_head_quantized: Optional[bool] = None, + ) -> None: + super().__init__() + self.method = method + self.weight_bits = weight_bits + self.group_size = group_size + self.modules_to_not_convert = modules_to_not_convert or [] + self.desc_act = desc_act + self.lm_head_quantized = lm_head_quantized + self.pack_factor = 32 // self.weight_bits + + if self.weight_bits not in [4]: + raise ValueError(f"IPEX quantization supports weight bits [4], " + f"but got {self.weight_bits}.") + + if self.method not in ["awq", "gptq"]: + raise ValueError(f"IPEX quantization supports [awq, gptq], " + f"but got {self.method}.") + + def __repr__(self) -> str: + return (f"IPEXConfig(method={self.method}," + f"weight_bits={self.weight_bits}, " + f"group_size={self.group_size})") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "ipex" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.float16] + + @classmethod + def get_min_capability(cls) -> int: + return -1 + + @staticmethod + def get_config_filenames() -> list[str]: + return [ + "quant_config.json", + "quantize_config.json", + ] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "IPEXConfig": + method = cls.get_from_keys(config, ["quant_method"]).lower() + if method == "awq": + weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) + group_size = cls.get_from_keys(config, + ["q_group_size", "group_size"]) + modules_to_not_convert = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None) + return cls(method, weight_bits, group_size, modules_to_not_convert, + False, False) + # otherwise for gptq + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False) + return cls(method, weight_bits, group_size, [], desc_act, + lm_head_quantized) + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + if not current_platform.is_cpu() and not current_platform.is_xpu(): + return None + + quant_method = hf_quant_cfg.get("quant_method", "").lower() + + if quant_method in ["awq", "gptq"]: + return cls.get_name() + + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["LinearMethodBase"]: + if isinstance(layer, LinearBase): + if self.method == "awq": + if is_layer_skipped_awq(prefix, self.modules_to_not_convert): + return UnquantizedLinearMethod() + return IPEXAWQLinearMethod(self) + if self.method == "gptq": + return IPEXGPTQLinearMethod(self) + return None + + +class IPEXGPTQLinearMethod(GPTQLinearMethod): + """GPTQ linear method using IPEX for the CPU/XPU backend. + """ + + def __init__(self, quant_config: IPEXConfig): + self.quant_config = quant_config # type: ignore + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + bias = layer.bias if not layer.skip_bias_add else None + + try: + import intel_extension_for_pytorch as ipex + if ipex.__version__ < MIN_IPEX_VERSION: + raise ImportError( + "intel_extension_for_pytorch version is " + "wrong. Please install " + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.") + except ImportError as err: + raise ImportError( + "Please install " + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via " + f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`" + " to use IPEX-AWQ linear method.") from err + # Using the compute dtype (lowp_mode) as INT8 to leverage instructions + # with better performance. + lowp_mode = ipex.quantization.WoqLowpMode.INT8 + # The weight will be de-packed from INT4 to INT8. + weight_dtype = ipex.quantization.WoqWeightDtype.INT4 + # The float activation will be quantized (dynamic, per-token) to INT8. + act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK + + qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( + weight_dtype=weight_dtype, + lowp_mode=lowp_mode, + act_quant_mode=act_quant_mode, + group_size=self.quant_config.group_size, + ) + layer.ipex_output_size = layer.qweight.shape[-1] + g_idx = layer.g_idx if self.quant_config.desc_act else None + layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \ + IPEXWeightOnlyQuantizedLinear.from_weight( + layer.qweight, + layer.scales, + layer.qzeros, + layer.qweight.size(0), + layer.ipex_output_size, + qconfig=qconfig, + g_idx=g_idx, + bias=bias, + group_size=self.quant_config.group_size, + quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"] + ) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + reshaped_x = x.reshape(-1, x.shape[-1]) + out = layer.ipex_qlinear(reshaped_x) + return out.reshape(x.shape[:-1] + (layer.ipex_output_size, )) + + +class IPEXAWQLinearMethod(AWQLinearMethod): + """AWQ linear method using IPEX for the CPU/XPU backend. + """ + + def __init__(self, quant_config: IPEXConfig): + self.quant_config = quant_config # type: ignore + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + super().process_weights_after_loading(layer=layer) + + bias = layer.bias if not layer.skip_bias_add else None + + try: + import intel_extension_for_pytorch as ipex + if ipex.__version__ < MIN_IPEX_VERSION: + raise ImportError( + "intel_extension_for_pytorch version is " + "wrong. Please install " + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.") + except ImportError as err: + raise ImportError( + "Please install " + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via " + f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`" + " to use IPEX-AWQ linear method.") from err + + # Using the compute dtype (lowp_mode) as INT8 to leverage instructions + # with better performance. + lowp_mode = ipex.quantization.WoqLowpMode.INT8 + # The weight will be de-packed from INT4 to INT8. + weight_dtype = ipex.quantization.WoqWeightDtype.INT4 + # The float activation will be quantized (dynamic, per-token) to INT8. + act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH + + qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( + weight_dtype=weight_dtype, + lowp_mode=lowp_mode, + act_quant_mode=act_quant_mode, + group_size=self.quant_config.group_size, + ) + + layer.ipex_output_size = layer.qweight.size( + 1) * self.quant_config.pack_factor + layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \ + IPEXWeightOnlyQuantizedLinear.from_weight( + layer.qweight, + layer.scales, + layer.qzeros, + layer.qweight.size(0), + layer.ipex_output_size, + qconfig=qconfig, + bias=bias, + group_size=self.quant_config.group_size, + quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"] # type: ignore + ) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + reshaped_x = x.reshape(-1, x.shape[-1]) + out = layer.ipex_qlinear(reshaped_x) + return out.reshape(x.shape[:-1] + (layer.ipex_output_size, )) diff --git a/vllm/model_executor/layers/quantization/kernels/__init__.py b/vllm/model_executor/layers/quantization/kernels/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py new file mode 100644 index 0000000..07ecc09 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, Optional + +import torch + +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.scalar_type import ScalarType + + +@dataclass +class MPLinearLayerConfig: + full_weight_shape: tuple[int, int] # [in, out] + partition_weight_shape: tuple[int, int] + weight_type: ScalarType + act_type: torch.dtype + group_size: int + zero_points: bool + has_g_idx: bool + + +class MPLinearKernel(ABC): + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + raise NotImplementedError + + @classmethod + @abstractmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + raise NotImplementedError + + def __init__(self, + c: MPLinearLayerConfig, + w_q_param_name: str, + w_s_param_name: str, + w_zp_param_name: Optional[str] = None, + w_gidx_param_name: Optional[str] = None) -> None: + assert self.can_implement(c) + self.config = c + self.w_q_name = w_q_param_name + self.w_s_name = w_s_param_name + if c.zero_points: + assert w_zp_param_name is not None + if c.has_g_idx: + assert w_gidx_param_name is not None + self.w_zp_name = w_zp_param_name + self.w_gidx_name = w_gidx_param_name + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError + + def _transform_param(self, layer: torch.nn.Module, name: Optional[str], + fn: Callable) -> None: + if name is not None and getattr(layer, name, None) is not None: + + old_param = getattr(layer, name) + new_param = fn(old_param) + # replace the parameter with torch.nn.Parameter for TorchDynamo + # compatibility + replace_parameter( + layer, name, + torch.nn.Parameter(new_param.data, requires_grad=False)) + + def _get_weight_params( + self, layer: torch.nn.Module) -> tuple[ + torch.Tensor, # w_q + torch.Tensor, # w_s + Optional[torch.Tensor], # w_zp, + Optional[torch.Tensor] # w_gidx + ]: + return ( + getattr(layer, self.w_q_name), + getattr(layer, self.w_s_name), + getattr(layer, self.w_zp_name or "", None), + getattr(layer, self.w_gidx_name or "", None), + ) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py new file mode 100644 index 0000000..0bf0d53 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import vllm.envs as envs +from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501 + AllSparkLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas import ( # noqa: E501 + BitBLASLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 + ExllamaLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501 + MacheteLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501 + MarlinLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501 + MPLinearKernel, MPLinearLayerConfig) +from vllm.platforms import current_platform + +# in priority/performance order (when available) +_POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [ + MacheteLinearKernel, + AllSparkLinearKernel, + MarlinLinearKernel, + BitBLASLinearKernel, + ExllamaLinearKernel, +] + + +def choose_mp_linear_kernel( + config: MPLinearLayerConfig, + compute_capability: Optional[int] = None) -> type[MPLinearKernel]: + """ + Choose an MPLinearKernel that can implement the given config for the given + compute capability. Attempts to choose the best kernel in terms of + performance. + + Args: + config (MPLinearLayerConfig): Description of the linear layer to be + implemented. + compute_capability (Optional[int], optional): The compute capability of + the target device, if None uses `current_platform` to get the compute + capability. Defaults to None. + + Raises: + ValueError: If no kernel can implement the given config. + + Returns: + type[MPLinearKernel]: Chosen kernel. + """ + if compute_capability is None: + if current_platform is None: + raise ValueError("Cannot determine compute capability") + _cc = current_platform.get_device_capability() + compute_capability = _cc[0] * 10 + _cc[1] + + failure_reasons = [] + for kernel in _POSSIBLE_KERNELS: + if kernel.__name__ in envs.VLLM_DISABLED_KERNELS: + failure_reasons.append( + f' {kernel.__name__} disabled by environment variable') + continue + + if kernel.get_min_capability() > compute_capability: + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel.get_min_capability()}, current compute capability " + f"is {compute_capability}") + continue + + can_implement, failure_reason = kernel.can_implement(config) + if can_implement: + return kernel + else: + failure_reasons.append( + f' {kernel.__name__} cannot implement due to: {failure_reason}' + ) + + raise ValueError( + "Failed to find a kernel that can implement the "\ + "WNA16 linear layer. Reasons: \n" + + '\n'.join(failure_reasons)) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py new file mode 100644 index 0000000..785e559 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.allspark_utils import ( + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, check_allspark_supported_dtype_shape) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class AllSparkLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + if c.has_g_idx: + return False, "Act reordering currently not supported by AllSpark" + + if c.zero_points: + return False, "Zero points currently not supported by AllSpark" + + return check_allspark_supported_dtype_shape( + c.partition_weight_shape[0], # in_features + c.partition_weight_shape[1], # out_features + c.group_size, + c.weight_type, + c.act_type) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = getattr(layer, self.w_q_name).device + c = self.config + + # prepare the parameters required for the kernel + properties = torch.cuda.get_device_properties(device.index) + sm_count = properties.multi_processor_count + sm_version = properties.major * 10 + properties.minor + gemm_args = {} + gemm_args['sm_count'] = sm_count + gemm_args['sm_version'] = sm_version + + self.gemm_args = gemm_args + + # transform param weight, scale + old_weight_param = getattr(layer, self.w_q_name) + old_scale_param = getattr(layer, self.w_s_name) + + assert isinstance(old_weight_param, BasevLLMParameter) + permute_param_layout_(old_weight_param, + input_dim=0, + output_dim=1, + packed_dim=0) + + assert isinstance(old_scale_param, BasevLLMParameter) + permute_param_layout_(old_scale_param, input_dim=0, output_dim=1) + + # unpack weight from K / 4 x N int32 to K x N uint8 + new_weight_param = torch.nn.Parameter(old_weight_param.data, + requires_grad=False) + new_weight_param.data = new_weight_param.data.t().contiguous().view( + dtype=torch.uint8) + new_weight_param.data = new_weight_param.data.t().contiguous() + + new_scale_param = torch.nn.Parameter(old_scale_param.data, + requires_grad=False) + + # reorder K x N weight as N32K16 format for Ampere W8A16 + new_weight_param.data, new_scale_param.data, _ = \ + ops.allspark_repack_weight( + new_weight_param.data, new_scale_param.data, None, + c.zero_points) + + replace_parameter(layer, self.w_q_name, new_weight_param.data) + replace_parameter(layer, self.w_s_name, new_scale_param.data) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + gemm_args = self.gemm_args + w_q, w_s, _, _ = self._get_weight_params(layer) + + reshaped_x = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + output = ops.allspark_w8a16_gemm( + a=reshaped_x, + b_qweight=w_q, + b_scales=w_s, + b_qzeros=None, + n=c.partition_weight_shape[1], + group_size=c.group_size, + sm_count=gemm_args['sm_count'], + sm_version=gemm_args['sm_version'], + CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + has_zp=c.zero_points, + n32k16_reorder=True) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py new file mode 100644 index 0000000..649d07b --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py @@ -0,0 +1,300 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_GROUP_SIZES, + MINIMUM_BITBLAS_VERSION, bitblas_make_empty_g_idx, bitblas_sort_g_idx, + check_bitblas_supports_shape, query_bitblas_supported_quant_types, + unpack_gptq_qweight, unpack_gptq_qzeros) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + +logger = init_logger(__name__) + + +class BitBLASLinearKernel(MPLinearKernel): + + OPT_FEATURES: list[int] = BITBLAS_OPTIMIZE_FEATURES + ENABLE_TUNING: bool = True + MATMUL_LAYOUT: str = "nt" + BITBLAS_DTYPES: dict[torch.dtype, str] = { + torch.float32: "float32", + torch.float16: "float16", + torch.bfloat16: "bfloat16", + torch.half: "float16", + torch.int8: "int8", + } + bitblas_matmul: object = None + + def __init__( + self, + c: MPLinearLayerConfig, + w_q_param_name: str, + w_s_param_name: str, + w_zp_param_name: Optional[str] = None, + w_gidx_param_name: Optional[str] = None, + bitblas_quant_config: Optional[QuantizationConfig] = None, + ): + self.quant_config = bitblas_quant_config + super().__init__(c, w_q_param_name, w_s_param_name, w_zp_param_name, + w_gidx_param_name) + + def repack_bitblas_from_gptq( + self, + b_q_weight: torch.Tensor, + scales: torch.Tensor, + qzeros: Optional[torch.Tensor] = None, + ): + from bitblas.quantization.utils import general_compress + assert self.bitblas_matmul is not None, "bitblas_matmul is None" + + quant_config = self.quant_config + # qweight in gptq old quant linear stored with + # (outfeatures, infeatures), should be transposed. + qweight = b_q_weight.T.contiguous().view( + quant_config.torch_storage_dtype) # type: ignore[union-attr] + intweight = unpack_gptq_qweight( + qweight, + quant_config.weight_bits).contiguous() # type: ignore[union-attr] + if self.bitblas_matmul.weight_transform is not None: # type: ignore[attr-defined] + qweight = self.bitblas_matmul.weight_transform( # type: ignore[attr-defined] + intweight.cpu()).cuda() + # scales in gptq old quant linear stored with + # (infeatures // group_size, outfeatures), should be transposed. + scales = scales.T.contiguous() + + if qzeros is None: + return qweight, scales, None + + # qzeros should be de-quantized to int zeros. + weight_bits = quant_config.weight_bits # type: ignore[union-attr] + intzeros = unpack_gptq_qzeros(qzeros, weight_bits).T.contiguous() + zeros: Optional[torch.Tensor] = None + zeros_mode = self.bitblas_matmul.config.zeros_mode # type: ignore[attr-defined] + if zeros_mode == "original": + zeros = intzeros.to(torch.float16).contiguous() + elif zeros_mode == "rescale": + assert zeros is not None, "zeros should not be None" + zeros[:, :] = intzeros.to(torch.float16)[:, :] * scales[:, :] + elif zeros_mode == "quantized": + zeros = ( + torch.Tensor( + general_compress( + intzeros.T.contiguous().cpu().numpy(), + weight_bits, + )).to(qweight.device). + to(quant_config.torch_storage_dtype # type: ignore[union-attr] + ).contiguous()) + else: + raise ValueError("Unsupported zeros type: {}".format(zeros_mode)) + + return qweight, scales, zeros + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + + is_bitblas_installed = True + + try: + import bitblas + if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: + raise ImportError( + "bitblas version is wrong. Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + except ImportError: + is_bitblas_installed = False + + if not is_bitblas_installed: + return False, "bitblas is not installed. Please install bitblas "\ + "by running `pip install bitblas>="\ + f"{MINIMUM_BITBLAS_VERSION}`" + + quant_types = query_bitblas_supported_quant_types(c.zero_points) + if c.weight_type not in quant_types: + return False, (f"Quant type ({c.weight_type}) not supported by" + f" BitBLAS, supported types are: {quant_types}") + + if c.group_size not in BITBLAS_SUPPORTED_GROUP_SIZES: + return False, (f"Group size ({c.group_size}) not supported by " + "BitBLAS, supported group sizes are: " + f"{BITBLAS_SUPPORTED_GROUP_SIZES}") + + return check_bitblas_supports_shape( + c.partition_weight_shape[1], # out_features + c.partition_weight_shape[0], # in_features + c.full_weight_shape[0], # in_features + c.group_size) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = getattr(layer, self.w_q_name).device + c = self.config + quant_config = self.quant_config + + # Default names since bitblas requires empty parameters for these, + # TODO: remove this requirement from bitblas (allow optional tensors) + if self.w_gidx_name is None: + self.w_gidx_name = "g_idx" + if self.w_zp_name is None: + self.w_zp_name = "qzeros" + + if c.has_g_idx: + g_idx, g_idx_sort_indices = bitblas_sort_g_idx( + getattr(layer, self.w_gidx_name)) + self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + else: + setattr(layer, self.w_gidx_name, bitblas_make_empty_g_idx(device)) + layer.g_idx_sort_indices = bitblas_make_empty_g_idx(device) + + if c.zero_points: + raise NotImplementedError("Zero points not supported by BitBLAS") + else: + setattr(layer, self.w_zp_name, bitblas_make_empty_g_idx(device)) + + # Repack weights + bitblas_qweight, bitblas_scales, bitblas_qzeros = ( + self.repack_bitblas_from_gptq( + layer.qweight, + layer.scales, + None if quant_config.is_sym else # type: ignore[union-attr] + layer.qzeros, # type: ignore[union-attr] + )) + replace_parameter(layer, self.w_q_name, bitblas_qweight) + replace_parameter(layer, self.w_s_name, bitblas_scales) + if bitblas_qzeros is not None: + replace_parameter(layer, self.w_zp_name, bitblas_qzeros) + + def configure_bitblas_matmul( + self, + infeatures: int, + outfeatures: int, + params_dtype: torch.dtype, + bias: bool, + ) -> None: + enable_tuning = self.ENABLE_TUNING + layout = self.MATMUL_LAYOUT + bits = self.quant_config.weight_bits # type: ignore[union-attr] + self._configure_bitblas_matmul( + infeatures, + outfeatures, + params_dtype, + enable_tuning, + bias, + layout, + bits, + ) + + def _configure_bitblas_matmul( + self, + infeatures, + outfeatures, + params_dtype, + enable_tuning, + bias, + layout, + bits, + ): + from bitblas import MatmulConfig + bitblas_dtype = self.BITBLAS_DTYPES[params_dtype] + quant_config = self.quant_config + with_scaling = False + with_zeros = False + group_size = quant_config.group_size # type: ignore[union-attr] + zeros_mode = quant_config.zeros_mode # type: ignore[union-attr] + if quant_config.quant_method == "gptq": # type: ignore[union-attr] + with_scaling = True + with_zeros = True + W_dtype = f"uint{bits}" + if quant_config.is_sym: # type: ignore[union-attr] + with_zeros = False + W_dtype = f"int{bits}" + else: + raise ValueError( + f"Unsupported quant_method {quant_config.quant_method}" # type: ignore[union-attr] + ) # type: ignore[union-attr] + + matmul_config = MatmulConfig( + M=self.OPT_FEATURES, + N=outfeatures, + K=infeatures, + A_dtype=bitblas_dtype, + W_dtype=W_dtype, + out_dtype=bitblas_dtype, + accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype, + storage_dtype=quant_config. # type: ignore[union-attr] + storage_dtype, # type: ignore[union-attr] + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + with_bias=bias, + layout=layout, + zeros_mode=zeros_mode, + ) + self.bitblas_matmul = self._get_or_create_bitblas_operator( + matmul_config, enable_tuning) + + def _get_or_create_bitblas_operator(self, config, enable_tuning): + from bitblas import Matmul, auto_detect_nvidia_target + from bitblas.cache import get_database_path, global_operator_cache + BITBLAS_DATABASE_PATH = get_database_path() + BITBLAS_TARGET = auto_detect_nvidia_target() + + if global_operator_cache.size() == 0: + global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, + BITBLAS_TARGET) + + bitblas_matmul = global_operator_cache.get(config) + if bitblas_matmul is None: + bitblas_matmul = Matmul(config, + target=BITBLAS_TARGET, + enable_tuning=False) + if enable_tuning: + bitblas_matmul.hardware_aware_finetune(topk=20) + global_operator_cache.add(config, bitblas_matmul) + global_operator_cache.save_into_database( + BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + TUNING_MESSAGE = ( + f"BitBLAS Operator {config} tuned and saved to database.") + logger.info(TUNING_MESSAGE) + else: + _message = f"BitBLAS Operator {config} created without tuning. " + logger.info(_message) + else: + _message = f"BitBLAS Operator {config} retrieved from cache." + logger.info(_message) + return bitblas_matmul + + def apply_gptq_bitblas_linear( + self, + layer: torch.nn.Module, + x: torch.Tensor, + ) -> torch.Tensor: + output_size_per_partition = self.config.partition_weight_shape[1] + out_shape = x.shape[:-1] + (output_size_per_partition, ) + args = [x, layer.qweight, layer.scales] + if self.bitblas_matmul.config.with_zeros: # type: ignore[attr-defined] + args.append(layer.qzeros) + output = self.bitblas_matmul(*args) # type: ignore[operator] + return output.view(out_shape) + + def apply_weights(self, layer, x, bias=None): + NOT_IMPLEMENT_MESSAGE = ( + f"{self.__class__.__name__}.apply_weights is not implemented. " + "Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead") + raise NotImplementedError(NOT_IMPLEMENT_MESSAGE) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py new file mode 100644 index 0000000..fef333e --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_quantized_values_into_int32) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) +from vllm.scalar_type import scalar_types + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class ExllamaLinearKernel(MPLinearKernel): + SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] + # In theory supports `scalar_types.uint2b2, scalar_types.uint3b4` too but + # currently untested so not added to the list + + @classmethod + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + if c.has_g_idx and\ + c.partition_weight_shape[0] != c.full_weight_shape[0]: + return False, "Act reordering currently not supported by Exllama, "\ + "when the input features are partitioned across "\ + "devices" + + if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0: + return False, "Output features must be a multiple of the pack " \ + "factor (32 / num_bits) so that we can correctly " \ + "pack the zero points" + + if c.act_type != torch.float16: + return False, "Exllama only supports float16 activations" + + if c.weight_type not in cls.SUPPORTED_QUANT_TYPES: + return False, f"Quant type ({c.weight_type}) not supported by "\ + "Exllama, supported types are: "\ + f"{cls.SUPPORTED_QUANT_TYPES}" + + if c.full_weight_shape[0] % c.group_size != 0: + return False, f"Group size ({c.group_size}) does not evenly divide"\ + " the number of input features "\ + f"({c.full_weight_shape[0]})" + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module): + c = self.config + + # For Exllama, we need to set a zero-point tensor if there is not one + if not c.zero_points: + self.w_zp_name = "qzeros" + device = getattr(layer, self.w_q_name).device + groups = c.partition_weight_shape[0] // c.group_size + out_features = c.partition_weight_shape[1] + + if c.weight_type.has_bias(): + # if the type has a bias we have to create a zeros tensor that + # contains the bias values repeated for each group (-1 due to + # a bug in the original GPTQ checkpoint format leading to + # exllama kernel adding 1 to the zero points during inference) + # Documentation of the bug can be found here: + # https://garden.danieldk.eu/GPTQ-Checkpoint-Format + zeros = torch.full((groups, out_features), + c.weight_type.bias - 1, + dtype=torch.int32, + device=device) + else: + raise NotImplementedError( + "A 0 zero-point is not supported by Exllama due to " + "a bug in the original GPTQ checkpoint format leading to " + "exllama kernel adding 1 to the zero points during " + "inference") + zeros = pack_quantized_values_into_int32(zeros, + c.weight_type, + packed_dim=1) + setattr(layer, self.w_zp_name, + torch.nn.Parameter(zeros, requires_grad=False)) + + if c.has_g_idx: + + def transform_w_g_idx(x): + # Exllama wants the permutation array instead of the group + # indices + return torch.argsort(x).to(torch.int) + + self._transform_param(layer, self.w_gidx_name, transform_w_g_idx) + else: + self.w_gidx_name = "g_idx" + empty_g_idx = torch.nn.Parameter(torch.empty((0, ), + dtype=torch.int, + device=device), + requires_grad=False) + setattr(layer, self.w_gidx_name, empty_g_idx) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + assert self.w_gidx_name is not None + g_idx = getattr(layer, self.w_gidx_name) + + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x_cont = x.data.contiguous() + ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits) + return x_cont + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous() + return x.to(dtype=c.act_type) + + # Repack weights and scales for Machete + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer) + + assert w_zp is not None, "Zero points are required by Exllama" + assert w_g_idx is not None, "Group index is required by Exllama" + output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True, + c.weight_type.size_bits) + + if bias is not None: + output.add_(bias) + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py new file mode 100644 index 0000000..12eb9d1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from functools import partial +from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.machete_utils import ( + check_machete_supports_shape, query_machete_supported_group_sizes, + query_machete_supported_quant_types) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_quantized_values_into_int32, unpack_quantized_values_into_int32) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class MacheteLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + + if c.has_g_idx and\ + c.partition_weight_shape[0] != c.full_weight_shape[0]: + return False, "Act reordering currently not supported by Machete, "\ + "when the input features are partitioned across "\ + "devices" + + if c.weight_type not in query_machete_supported_quant_types( + c.zero_points): + return False, f"Quant type ({c.weight_type}) not supported by "\ + "Machete, supported types are: "\ + f"{query_machete_supported_quant_types(c.zero_points)}" + + if c.group_size not in query_machete_supported_group_sizes(c.act_type): + return False, f"Group size ({c.group_size}) not supported by "\ + "Machete, supported group sizes are: "\ + f"{query_machete_supported_group_sizes(c.act_type)}" + + return check_machete_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1]) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + # `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module): + c = self.config + + if c.has_g_idx: + assert self.w_gidx_name is not None + perm = torch.argsort(getattr(layer, self.w_gidx_name))\ + .to(torch.int) + + self.act_perm = lambda x: x[:, perm] + # use `ops.permute_cols` if possible + if c.act_type in [torch.float16, torch.bfloat16] \ + and c.partition_weight_shape[0] % 8 == 0: + self.act_perm = partial(ops.permute_cols, perm=perm) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + if c.has_g_idx: + x_unpacked = unpack_quantized_values_into_int32(x.data, + c.weight_type, + packed_dim=0) + x_perm = x_unpacked[perm, :] + x.data = pack_quantized_values_into_int32(x_perm, + c.weight_type, + packed_dim=0) + x.data = ops.machete_prepack_B(x.data.t().contiguous().t(), + a_type=c.act_type, + b_type=c.weight_type, + group_scales_type=c.act_type) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous() + return x + + def transform_w_zp(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=1) + x_unpacked = unpack_quantized_values_into_int32(x.data, + c.weight_type, + packed_dim=1) + w_s = getattr(layer, self.w_s_name).data + # pre-apply scales to zero-points + x.data = (-1.0 * w_s * (x_unpacked.to(w_s.dtype))).contiguous() + return x + + # Repack weights and scales for Machete + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + if c.zero_points: + self._transform_param(layer, self.w_zp_name, transform_w_zp) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, w_zp, _ = self._get_weight_params(layer) + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + if c.has_g_idx: + x_2d = self.act_perm(x_2d) + + output = ops.machete_mm(a=x_2d, + b_q=w_q, + b_type=c.weight_type, + b_group_zeros=w_zp, + b_group_scales=w_s, + b_group_size=c.group_size) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py new file mode 100644 index 0000000..1597492 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, + check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, + marlin_make_workspace_new, marlin_permute_scales, marlin_sort_g_idx, + marlin_zero_points, query_marlin_supported_quant_types, unpack_cols) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class MarlinLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + + quant_types = query_marlin_supported_quant_types(c.zero_points) + if c.weight_type not in quant_types: + return False, f"Quant type ({c.weight_type}) not supported by"\ + f" Marlin, supported types are: {quant_types}" + + if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Marlin, supported group sizes are: "\ + f"{MARLIN_SUPPORTED_GROUP_SIZES}" + + return check_marlin_supports_shape( + c.partition_weight_shape[1], # out_features + c.partition_weight_shape[0], # in_features + c.full_weight_shape[0], # in_features + c.group_size) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = getattr(layer, self.w_q_name).device + c = self.config + + row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0]) + self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel) + + # Allocate marlin workspace. + self.workspace = marlin_make_workspace_new(device) + + # Default names since marlin requires empty parameters for these, + # TODO: remove this requirement from marlin (allow optional tensors) + if self.w_gidx_name is None: + self.w_gidx_name = "g_idx" + if self.w_zp_name is None: + self.w_zp_name = "w_zp" + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = ops.gptq_marlin_repack(x.data.contiguous(), + perm=layer.g_idx_sort_indices, + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = marlin_permute_scales(x.data.contiguous(), + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + group_size=c.group_size) + return x + + if c.has_g_idx: + g_idx, g_idx_sort_indices = marlin_sort_g_idx( + getattr(layer, self.w_gidx_name)) + self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + else: + setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + if c.zero_points: + grouped_k = (c.partition_weight_shape[0] // + c.group_size if c.group_size != -1 else 1) + self._transform_param(layer, self.w_zp_name, lambda x: \ + marlin_zero_points( + unpack_cols(x.t(), c.weight_type.size_bits, + grouped_k, + c.partition_weight_shape[1]), + size_k=grouped_k, + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits)) + else: + setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer) + + # `process_weights_after_loading` will ensure w_zp and w_gidx are not + # None for marlin + return apply_gptq_marlin_linear( + input=x, + weight=w_q, + weight_scale=w_s, + weight_zp=w_zp, # type: ignore + g_idx=w_gidx, # type: ignore + g_idx_sort_indices=layer.g_idx_sort_indices, + workspace=self.workspace, + wtype=c.weight_type, + input_size_per_partition=c.partition_weight_shape[0], + output_size_per_partition=c.partition_weight_shape[1], + is_k_full=self.is_k_full, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py new file mode 100644 index 0000000..9ebf5f3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional + +import torch + + +@dataclass +class ScaledMMLinearLayerConfig: + is_channelwise: bool + is_static_input_scheme: bool + input_symmetric: bool + + +class ScaledMMLinearKernel(ABC): + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + raise NotImplementedError + + @classmethod + @abstractmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + raise NotImplementedError + + def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str, + w_s_param_name: str, i_s_param_name: str, + i_zp_param_name: str, azp_adj_param_name: str) -> None: + assert self.can_implement(c) + self.config = c + self.w_q_name = w_q_param_name + self.w_s_name = w_s_param_name + self.i_s_name = i_s_param_name + self.i_zp_name = i_zp_param_name + self.azp_adj_name = azp_adj_param_name + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError + + def _get_weight_params( + self, layer: torch.nn.Module) -> tuple[ + torch.Tensor, # weight + torch.Tensor, # weight_scale + Optional[torch.Tensor], # input_scale, + Optional[torch.Tensor], # input_zp + Optional[torch.Tensor], # azp_adj + ]: + return ( + getattr(layer, self.w_q_name), + getattr(layer, self.w_s_name), + getattr(layer, self.i_s_name), + getattr(layer, self.i_zp_name), + getattr(layer, self.azp_adj_name), + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py new file mode 100644 index 0000000..18f5ce0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from typing import Optional + +from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( + AiterScaledMMLinearKernel) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( + CutlassScaledMMLinearKernel) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearKernel, ScaledMMLinearLayerConfig) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( + TritonScaledMMLinearKernel) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( + XLAScaledMMLinearKernel) +from vllm.platforms import PlatformEnum, current_platform + +# in priority/performance order (when available) +_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { + PlatformEnum.CPU: [CutlassScaledMMLinearKernel], + PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], + PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], + PlatformEnum.TPU: [XLAScaledMMLinearKernel], +} + + +def choose_scaled_mm_linear_kernel( + config: ScaledMMLinearLayerConfig, + compute_capability: Optional[int] = None +) -> type[ScaledMMLinearKernel]: + """ + Choose an ScaledMMLinearKernel that can implement the given config for the + given compute capability. Attempts to choose the best kernel in terms of + performance. + + Args: + config (ScaledMMLinearLayerConfig): Description of the linear layer + to be implemented. + compute_capability (Optional[int], optional): The compute capability of + the target device, if None uses `current_platform` to get the + compute capability. Defaults to None. + + Raises: + ValueError: If no kernel can implement the given config. + + Returns: + type[ScaledMMLinearKernel]: Chosen kernel. + """ + + if compute_capability is None: + _cc = current_platform.get_device_capability() + if _cc is not None: + compute_capability = _cc[0] * 10 + _cc[1] + + failure_reasons = [] + for kernel in _POSSIBLE_KERNELS[current_platform._enum]: + if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\ + .split(","): + failure_reasons.append( + f' {kernel.__name__} disabled by environment variable') + continue + + # If the current platform uses compute_capability, + # make sure the kernel supports the compute cability. + if compute_capability is not None: + kernel_min_capability = kernel.get_min_capability() + if (kernel_min_capability is not None + and kernel_min_capability > compute_capability): + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel_min_capability}, current compute capability " + f"is {compute_capability}") + continue + + can_implement, failure_reason = kernel.can_implement(config) + if can_implement: + return kernel + else: + failure_reasons.append( + f' {kernel.__name__} cannot implement due to: {failure_reason}' + ) + + raise ValueError( + "Failed to find a kernel that can implement the "\ + "ScaledMM linear layer. Reasons: \n" + + '\n'.join(failure_reasons)) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py new file mode 100644 index 0000000..165548a --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.platforms import current_platform + +from .cutlass import CutlassScaledMMLinearKernel +from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig + + +class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + if not current_platform.is_rocm(): + return ( + False, + "AiterScaledMMLinearKernel requires `aiter` which is not " + + "currently supported on non-ROCm platform.") + + try: + import aiter # noqa: F401 # deliberately attempt to import aiter + except Exception: + return ( + False, + "AiterScaledMMLinearKernel requires `aiter` which is not " + + "installed on ROCm.") + # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled + if not ( + envs.VLLM_ROCM_USE_AITER_LINEAR \ + and envs.VLLM_ROCM_USE_AITER + ): + return (False, "AiterScaledMMLinearKernel is disabled. " + + "Enable by setting `VLLM_ROCM_USE_AITER=1` " + + "and `VLLM_ROCM_USE_AITER_LINEAR=1`. " + + "`VLLM_ROCM_USE_AITER_LINEAR` default is True.") + + if not c.input_symmetric: + return (False, + "AiterScaledMMLinearKernel only supports symmetric " + + "quantization.") + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + super().process_weights_after_loading(layer) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + `AiterScaledMMLinearKernel` implements a fused version of + `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)` + where scale_a * a and scale_b * b are implemented using numpy-style + broadcasting. + Currently only support per-tensor-per-tensor GEMM + and per-token-per-channel GEMM through AITER + w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support + ATIER block scaled GEMM and mix-precision GEMM. + """ + w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + + # ops.scaled_int8_quant supports both dynamic and static quant: + # * dynamic, i_s is None and x_s computed from x. + # * static, i_s is scalar and x_s is i_s. + symmetric = azp_adj is None + assert symmetric, ("AiterScaledMMLinearKernel only supports" + " symmetric quantization.") + x_q, x_s, x_zp = ops.scaled_int8_quant(x, + i_s, + i_zp, + symmetric=symmetric) + + assert x_zp is None, ("AiterScaledMMLinearKernel only supports" + " symmetric quantization.") + out_dtype = x.dtype + + assert (w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0) + assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) + assert bias is None or bias.shape[0] == w_q.shape[ + 1] and bias.dtype == out_dtype + + m = x_q.shape[0] # a + n = w_q.shape[1] # b + + per_tensor_scale_a = (x_s.numel() == 1) + per_tensor_scale_b = (w_s.numel() == 1) + per_token_scale_a = (x_s.numel() == m) + per_channel_scale_b = (w_s.numel() == n) + + # @TODO: + # Maybe broadcast the per-tensor-scale into per-channel-scale + # if one of the scale is a per-channel-scale. + # For now, it only supports: + # - per-tensor-per-tensor a8w8 scaled GEMM, and + # - per-token-per-channel a8w8 scaled GEMM + assert ((per_tensor_scale_a and per_tensor_scale_b) + or (per_token_scale_a and per_channel_scale_b)), ( + "Currently only support per-tensor-per-tensor GEMM " + + " and per-token-per-channel GEMM through AITER" + " w8a8 scaled gemm. `AiterScaledMMLinearKernel` " + + "does not support AITER block scaled GEMM.") + + from aiter import gemm_a8w8_CK + + # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects + # a to be [M, K] + # b to be [N, K] + # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format + return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py new file mode 100644 index 0000000..b865c5f --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise) +from vllm.platforms import current_platform + +from .ScaledMMLinearKernel import (ScaledMMLinearKernel, + ScaledMMLinearLayerConfig) + +from lmslim.layers.gemm.int8_utils import per_token_quant_int8 + + +class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + + if (not current_platform.is_cuda() and not current_platform.is_cpu()): + return False, "CutlassScaledMM requires running on CUDA or CPU." + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # WEIGHT + # Cutlass kernels need transposed weight. + weight = getattr(layer, self.w_q_name) + replace_parameter( + layer, self.w_q_name, + torch.nn.Parameter(weight.t().data, requires_grad=False)) + + # WEIGHT SCALE + # Cutlass kernels support only per-tensor and per-channel. + # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. + is_fused_module = len(layer.logical_widths) > 1 + weight_scale = getattr(layer, self.w_s_name) + if is_fused_module and not self.config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, + layer.logical_widths) + replace_parameter( + layer, self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False)) + + # INPUT SCALE + if self.config.is_static_input_scheme: + input_scale = getattr(layer, self.i_s_name) + + if self.config.input_symmetric: + replace_parameter( + layer, self.i_s_name, + torch.nn.Parameter(input_scale.max(), requires_grad=False)) + setattr(layer, self.i_zp_name, None) + else: + input_zero_point = getattr(layer, self.i_zp_name) + + # reconstruct the ranges + int8_traits = torch.iinfo(torch.int8) + azps = input_zero_point.to(dtype=torch.int32) + range_max = (input_scale * (int8_traits.max - azps)).max() + range_min = (input_scale * (int8_traits.min - azps)).min() + + scale = (range_max - range_min) / (int8_traits.max - + int8_traits.min) + replace_parameter( + layer, self.i_s_name, + torch.nn.Parameter(scale, requires_grad=False)) + + # AZP loaded as int8 but used as int32 + azp = (int8_traits.min - + range_min / scale).to(dtype=torch.int32) + replace_parameter(layer, self.i_zp_name, + torch.nn.Parameter(azp, requires_grad=False)) + + else: + setattr(layer, self.i_s_name, None) + setattr(layer, self.i_zp_name, None) + + # azp_adj is the AZP adjustment term, used to account for weights. + # It does not depend on scales or azp, so it is the same for + # static and dynamic quantization. + # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md + # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md + if not self.config.input_symmetric: + weight = getattr(layer, self.w_q_name) + azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32) + if self.config.is_static_input_scheme: + # cutlass_w8a8 requires azp to be folded into azp_adj + # in the per-tensor case + azp_adj = getattr(layer, self.i_zp_name) * azp_adj + setattr(layer, self.azp_adj_name, + torch.nn.Parameter(azp_adj, requires_grad=False)) + else: + setattr(layer, self.azp_adj_name, None) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + + # ops.scaled_int8_quant supports both dynamic and static quant: + # * dynamic, i_s is None and x_s computed from x. + # * static, i_s is scalar and x_s is i_s. + symmetric = azp_adj is None + if i_s is None and i_zp is None and symmetric is True: + x_q, x_s=per_token_quant_int8(x) + x_zp =None + + else: + x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(), + i_s, + i_zp, + symmetric=symmetric) + + if x_zp is not None: + # Currently, static is always per-tensor and dynamic is per-token + static = i_zp is not None + azp = None if static else x_zp + return ops.cutlass_scaled_mm_azp(x_q, + w_q, + scale_a=x_s, + scale_b=w_s, + out_dtype=x.dtype, + azp_adj=azp_adj, + azp=azp, + bias=bias) + return ops.cutlass_scaled_mm(x_q, + w_q, + scale_a=x_s, + scale_b=w_s, + out_dtype=x.dtype, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py new file mode 100644 index 0000000..817565c --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm.platforms import current_platform + +from .cutlass import CutlassScaledMMLinearKernel +from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig + + +class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + if current_platform.is_cpu(): + return ( + False, + "TritonScaledMMLinearKernel requires Triton which is not " + + "currently supported on CPU.") + if not c.input_symmetric: + return (False, + "TritonScaledMMLinearKernel only supports symmetric " + + "quantization.") + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + super().process_weights_after_loading(layer) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return super().apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py new file mode 100644 index 0000000..3de28af --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import warnings +from typing import Optional + +import torch +from functorch.experimental.control_flow import cond # noqa: F401 + +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise) +from vllm.platforms import current_platform + +from .ScaledMMLinearKernel import (ScaledMMLinearKernel, + ScaledMMLinearLayerConfig) + + +class XLAScaledMMLinearKernel(ScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + raise NotImplementedError( + "TPU platform does have a concept of compute capability, " + "this method should not be called.") + + @classmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + + if not current_platform.is_tpu(): + return False, "ScaledMMXLA requires running on TPU." + + if c.is_static_input_scheme: + return False, "ScaledMMXLA requires dynamic activation scales." + + if not c.input_symmetric: + return False, "ScaledMMXLA requires symmetric activation scales." + + if not c.is_channelwise: + return False, "ScaledMMXLA requires channelwise weight scales" + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # WEIGHT + # [out, in] (different than cutlass_scaled_mm) + weight = getattr(layer, self.w_q_name) + replace_parameter(layer, self.w_q_name, + torch.nn.Parameter(weight.data, requires_grad=False)) + + # WEIGHT SCALE + # XLA kernels support only per-tensor and per-channel. + # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. + is_fused_module = len(layer.logical_widths) > 1 + weight_scale = getattr(layer, self.w_s_name) + if is_fused_module and not self.config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, + layer.logical_widths) + + # [out_channel,] (different than cutlass_scaled_mm) + weight_scale = weight_scale.squeeze(-1) + replace_parameter( + layer, self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False)) + + # Only support symmetric dynamic activation quantization. + setattr(layer, self.i_s_name, None) + setattr(layer, self.i_zp_name, None) + setattr(layer, self.azp_adj_name, None) + + # Filter warning for cond usage in apply_weights. It is okay + # to specialize the graph since bias is not dynamic. + warnings.filterwarnings( + "ignore", + message= + "Pred is a Python constant. When used with torch.cond, it specializes on one of the branches." # noqa: E501 + ) + + def no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): + return x + + def add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): + return x + bias + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + w_q, w_s, _, _, _ = self._get_weight_params(layer) + + import torch_xla.experimental.xla_quantized_matmul # noqa: F401 + out = torch.ops.xla.quantized_matmul(x, + w_q, + w_s, + zero_point=None, + block_size=-1, + int4_weight=False, + quantize_activation=True) + # `quantized_matmul` output is fp32, cast it down to bf16 for perf + out = out.to(x.dtype) + # Explicitly capture control flow to make dynamo happy. + # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501 + return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias]) diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py new file mode 100644 index 0000000..e560467 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.platforms import current_platform + +logger = init_logger(__name__) + + +class BaseKVCacheMethod(QuantizeMethodBase): + """ + Quant method that adds `_k_scale` and `_v_scale` attributes to the + Attention layer to support loading those scaling factors from checkpoints. + The k/v_scale will be used to: + - quantize k/v_cache entries before saving them to the cache + - dequantize k/v_cache entries before fetching them from the cache + + :param quant_config: the appropriate QuantizationConfig + """ + + def __init__(self, quant_config: QuantizationConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module): + """ + Create "weight" (aka q_scale, k_scale and v_scale) + for an attention layer. + """ + # Initialize the Q and KV cache scales to -1.0, an invalid value. + # If the q and k/v_scales appear in the checkpoint, it will be + # overwritten when loading weights. + layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) + layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) + layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) + # Initialize P = softmax(QK^T) scales + layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) + + def apply(self, layer: torch.nn.Module) -> torch.Tensor: + raise RuntimeError( + f"{self.__class__.__name__}.apply should not be called.") + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 + # regardless whether the kv-scale is available in the checkpoint. + # No need to process kv scales after loading if we are going to + # calculate them on the fly. + if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales: + if layer.k_scale > 0.0 and layer.v_scale > 0.0: + # We prefer to use separate k_scale and v_scale if present + k_scale = layer.k_scale.to("cpu").tolist() + v_scale = layer.v_scale.to("cpu").tolist() + if current_platform.is_fp8_fnuz(): + k_scale *= 2 + v_scale *= 2 + elif layer.k_scale < 0.0 and layer.v_scale < 0.0: + # If no scales were loaded (both scales are invalid negative + # values), use the default value of 1.0 + k_scale = 1.0 + v_scale = 1.0 + else: + # If we find a single kv_scale in the checkpoint, we remap + # kv_scale to k_scale during weight loading, and duplicate + # k_scale to v_scale here + assert layer.k_scale > 0.0 + scale_to_duplicate = max(layer.k_scale, layer.v_scale) + k_scale = scale_to_duplicate.to("cpu").tolist() + v_scale = scale_to_duplicate.to("cpu").tolist() + if current_platform.is_fp8_fnuz(): + k_scale *= 2 + v_scale *= 2 + + if not isinstance(k_scale, float) or not isinstance( + v_scale, float): + raise ValueError("Only support per-tensor scaling factor " + "for fp8 KV cache") + + if layer.q_scale < 0.0: + logger.warning_once( + "Checkpoint does not provide a q scaling factor. " + "Setting it to k_scale. This only matters for " + "the flash-attn backend.") + layer._q_scale.copy_(k_scale) + + # These are used in the final Attention.forward() + layer._k_scale.copy_(k_scale) + layer._v_scale.copy_(v_scale) + layer._k_scale_float = k_scale + layer._v_scale_float = v_scale + if (k_scale == 1.0 and v_scale == 1.0 + and "e5m2" not in layer.kv_cache_dtype): + logger.warning_once( + "Using KV cache scaling factor 1.0 for fp8_e4m3. This " + "may cause accuracy issues. Please make sure k/v_scale " + "scaling factors are available in the fp8 checkpoint.") + + if layer.q_scale > 0.0: + q_scale = layer.q_scale + if current_platform.is_fp8_fnuz(): + q_scale *= 2 + layer.calculate_kv_scales = False + else: + q_scale = 1.0 + if layer.prob_scale > 0.0: + prob_scale = layer.prob_scale + if current_platform.is_fp8_fnuz(): + prob_scale *= 2 + else: + prob_scale = 1.0 + + is_singleton_float = lambda x: isinstance(x, float) or isinstance( + x, torch.Tensor) and x.numel() == 1 and x.is_floating_point() + if not is_singleton_float(q_scale) or not is_singleton_float( + prob_scale): + raise ValueError("Only support per-tensor scaling factor" + "for fp8-quantized Q/prob") + + # These are used in the final Attention.forward() + layer._q_scale.copy_(q_scale) + layer._prob_scale.copy_(prob_scale) + if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0 + or prob_scale == 1.0): + logger.warning_once( + f"Using uncalibrated q_scale {q_scale} and/or prob_scale " + f"{prob_scale} with fp8 attention. This may cause accuracy " + "issues. Please make sure q/prob scaling factors are " + "available in the fp8 checkpoint.") + + del layer.k_scale + del layer.v_scale + del layer.q_scale + del layer.prob_scale diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py new file mode 100644 index 0000000..18d1c13 --- /dev/null +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -0,0 +1,263 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) + +logger = init_logger(__name__) + + +class MarlinConfig(QuantizationConfig): + """Config class for Marlin. + + Reference: https://github.com/IST-DASLab/marlin/tree/master + """ + + def __init__( + self, + group_size: int, + lm_head_quantized: bool, + ) -> None: + super().__init__() + + # Group size for the quantization. + self.group_size = group_size + self.lm_head_quantized = lm_head_quantized + if self.group_size != 128 and self.group_size != -1: + raise ValueError( + "Currently, only group size 128 and -1 (channelwise) " + "is supported for Marlin, but got group_size of " + f"{self.group_size}") + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = 32 // 4 + + # Tile size used by marlin kernels. + self.tile_size = 16 + + # Min out_features dim + self.min_n_threads = 64 + + # Min in_features dim + self.min_k_threads = 128 + + # Max parallel problems to solve at once (improves large + # batch performance) + self.max_parallel = 16 + + # Permutation length used by the marlin kernels. + self.perm_len = 1024 + + def __repr__(self) -> str: + return (f"MarlinConfig(group_size={self.group_size}, " + f"lm_head_quantized={self.lm_head_quantized})") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "MarlinConfig": + group_size = cls.get_from_keys(config, ["group_size"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(group_size, lm_head_quantized) + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + # compat: autogptq >=0.8.0 use checkpoint_format: str + # compat: autogptq <=0.7.1 is_marlin_format: bool + is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin" + or hf_quant_cfg.get("is_marlin_format", False)) + + is_valid_user_quant = (user_quant is None or user_quant == "gptq" + or user_quant == "marlin") + + if is_marlin_format and is_valid_user_quant: + msg = ("The model is serialized in {} format. Using {} kernel.". + format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["MarlinLinearMethod"]: + if (isinstance(layer, LinearBase) or + (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + return MarlinLinearMethod(self) + return None + + +class MarlinLinearMethod(LinearMethodBase): + """Linear method for Marlin. + + Args: + quant_config: The Marlin quantization config. + """ + + def __init__(self, quant_config: MarlinConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused. + weight_loader = extra_weight_attrs["weight_loader"] + + if params_dtype != torch.float16: + raise ValueError( + f"The params dtype must be float16, but got {params_dtype}") + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.min_n_threads != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"min_n_threads = {self.quant_config.min_n_threads}.") + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"pack_factor = {self.quant_config.pack_factor}.") + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_k_threads != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"min_k_threads = {self.quant_config.min_k_threads}.") + if (self.quant_config.group_size != -1 and + input_size_per_partition % self.quant_config.group_size != 0): + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"group_size = {self.quant_config.group_size}.") + + # Check that we have at least 4 tiles horizontally in the shard + num_tiles_per_perm = self.quant_config.perm_len // ( + self.quant_config.tile_size**2) + if output_size_per_partition % num_tiles_per_perm != 0: + raise ValueError( + "Each permutation group must reside on the same gpu") + + # Quantized 4Bit weights packed into Int32. + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.tile_size, + output_size_per_partition * self.quant_config.tile_size // + self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + marlin_tile_size=self.quant_config.tile_size, + weight_loader=weight_loader) + + # Determine if channelwise or not + input_groups = (1 if self.quant_config.group_size == -1 else + input_size_per_partition // + self.quant_config.group_size) + + weight_scale_args = { + "data": + torch.empty( + input_groups, + output_size_per_partition, + device="cuda", + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + if input_groups == 1: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + + # Allocate workspace (Used for internal locking mechanism) + max_workspace_size = ( + output_size_per_partition // + self.quant_config.min_n_threads) * self.quant_config.max_parallel + + workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, + device="cuda", + dtype=torch.int), + weight_loader=weight_loader) + + layer.register_parameter("B", qweight) + layer.register_parameter("s", scales) + layer.register_parameter("workspace", workspace) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # required by torch.compile + layer.B = Parameter(layer.B.data, requires_grad=False) + layer.s = Parameter(layer.s.data, requires_grad=False) + layer.workspace = Parameter(layer.workspace.data, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.B + scales = layer.s + workspace = layer.workspace + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = scales.shape[1] + + output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m, + size_n, size_k) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py new file mode 100644 index 0000000..9db8753 --- /dev/null +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -0,0 +1,747 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Callable, Optional, Union + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter + +from vllm._custom_ops import (cutlass_scaled_fp4_mm, + cutlass_scaled_mm_supports_fp4, scaled_fp4_quant) +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, is_fp4_marlin_supported, + prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp, requantize_with_max_scale) +from vllm.model_executor.parameter import (ModelWeightParameter, + PerTensorScaleParameter) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + +QUANT_ALGOS = ["FP8", "NVFP4"] +KV_CACHE_QUANT_ALGOS = ["FP8"] + + +class ModelOptFp8Config(QuantizationConfig): + """Config class for ModelOpt FP8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + ) -> None: + super().__init__() + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning("Detected ModelOpt fp8 checkpoint. Please note that" + " the format is experimental and could change.") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "modelopt" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 89 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["hf_quant_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config": + quant_config = cls.get_from_keys(config, ["quantization"]) + quant_method = quant_config["quant_algo"] + if quant_method not in QUANT_ALGOS: + raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}" + " quantizations in vLLM. Please check the " + "`hf_quant_config.json` file for your model's " + "quant configuration.") + is_checkpoint_fp8_serialized = ("FP8" in quant_method) + + return cls(is_checkpoint_fp8_serialized) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + if isinstance(layer, LinearBase): + return ModelOptFp8LinearMethod(self) + elif isinstance(layer, Attention): + return ModelOptFp8KVCacheMethod(self) + return None + + +class ModelOptFp8LinearMethod(LinearMethodBase): + """Linear method for Model Optimizer static quantization. + Supports loading FP8 checkpoints with static weight scale and + activation scale. Future support might be added for dynamic + scales. + + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. Only support float8_e4m3fn datatype + Args: quant_config: The ModelOpt quantization config. + """ + + def __init__(self, quant_config: ModelOptFp8Config): + self.quant_config = quant_config + self.fp8_linear = Fp8LinearOp() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized else + params_dtype) + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=weight_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALE + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + # INPUT SCALE + scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", scale) + + def process_weights_after_loading(self, layer: Module) -> None: + weight = layer.weight + max_w_scale = layer.weight_scale.max() + if not (layer.weight_scale == layer.weight_scale[0]).all(): + max_w_scale, weight = requantize_with_max_scale( + layer.weight, layer.weight_scale, layer.logical_widths) + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.fp8_linear.apply(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias) + + +class ModelOptNvFp4Config(QuantizationConfig): + """Config class for ModelOpt FP4.""" + + def __init__( + self, + is_checkpoint_nvfp4_serialized: bool, + kv_cache_quant_algo: str, + exclude_modules: list[str], + group_size: int = 16, + ) -> None: + super().__init__() + self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized + if is_checkpoint_nvfp4_serialized: + logger.warning( + "Detected ModelOpt NVFP4 checkpoint. Please note that" + " the format is experimental and could change in future.") + + self.group_size = group_size + self.kv_cache_quant_algo = kv_cache_quant_algo + self.exclude_modules = exclude_modules + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "modelopt_fp4" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.half, torch.float8_e4m3fn] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["hf_quant_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config": + quant_config = cls.get_from_keys(config, ["quantization"]) + quant_method = quant_config["quant_algo"] + if quant_method not in QUANT_ALGOS: + raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}" + " quantizations in vLLM. Please check the " + "`hf_quant_config.json` file for your model's " + "quant configuration.") + is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method) + if ("group_size" and "kv_cache_quant_algo" + and "exclude_modules") not in quant_config: + raise ValueError("NVFP4 quantization requires group size and " + "kv_cache_quant_algo specified in " + "hf_quant_config.json") + kv_cache_quant_algo = quant_config["kv_cache_quant_algo"] + group_size = quant_config["group_size"] + exclude_modules = quant_config["exclude_modules"] + return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo, + exclude_modules, group_size) + + def is_layer_excluded(self, prefix: str, exclude_modules: list): + import regex as re + for pattern in exclude_modules: + regex_str = pattern.replace('.', r'\.').replace('*', r'.*') + if re.fullmatch(regex_str, prefix): + return True + return False + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + if isinstance(layer, LinearBase): + if (is_layer_skipped(prefix, self.exclude_modules) + or self.is_layer_excluded(prefix, self.exclude_modules)): + return UnquantizedLinearMethod() + return ModelOptNvFp4LinearMethod(self) + elif isinstance(layer, Attention): + return ModelOptFp8KVCacheMethod(self) + elif isinstance(layer, FusedMoE): + return ModelOptNvFp4FusedMoE(self) + return None + + +def cutlass_fp4_supported() -> bool: + if not current_platform.is_cuda(): + return False + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + return cutlass_scaled_mm_supports_fp4(capability) + + +class ModelOptFp8KVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from FP8 checkpoints. + """ + + def __init__(self, quant_config: Union[ModelOptFp8Config, + ModelOptNvFp4Config]): + super().__init__(quant_config) + + +class ModelOptNvFp4LinearMethod(LinearMethodBase): + """Linear method for Model Optimizer NVFP4. + Supports loading NVFP4 checkpoints with the following structure: + + input_scale: torch.float32, scalar , + weight: NVFP4(represented as byte) Shape: [1, X, y/2] + weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale, + weight_scale_2: torch.float32, scalar, + Args: quant_config: The ModelOpt quantization config. + """ + + def __init__(self, quant_config: ModelOptNvFp4Config): + self.quant_config = quant_config + self.cutlass_nvfp4_supported = cutlass_fp4_supported() + self.use_marlin = False + + if not self.cutlass_nvfp4_supported: + if is_fp4_marlin_supported(): + self.use_marlin = True + else: + raise ValueError("Current platform does not support NVFP4" + " quantization. Please use Blackwell and" + " above.") + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + if not self.quant_config.is_checkpoint_nvfp4_serialized: + raise ValueError("NVFP4 quantization was selected, " + " dynamic quantization is not supported.") + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + if (input_size_per_partition % 16 != 0): + raise ValueError("Unsupported model when in features size is " + "not multiple of 16") + # The nvfp4 weight is still represented as + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_nvfp4_serialized + else params_dtype) + # Weight + weight = ModelWeightParameter( + data=torch.empty( + # 2 fp4 items are packed in the input dimension + layer.output_size_per_partition, + layer.input_size_per_partition // 2, + dtype=torch.uint8), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + # Input Weight Scale + input_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("input_scale", input_scale) + + # Global Weight Scale + weight_scale_2 = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("weight_scale_2", weight_scale_2) + + # Per Block Weight Scale + weight_scale = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition // self.quant_config.group_size, + dtype=weight_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("weight_scale", weight_scale) + + def swizzle_blockscale(self, scale: torch.tensor): + assert (scale.dtype == torch.float8_e4m3fn) + # Pad and blockwise interleave weight_scale + scale_ndim = scale.ndim + if scale.ndim == 2: + scale = scale.unsqueeze(0) + assert scale.ndim == 3 + B, M, K = scale.shape + round_up_multiple = lambda x, m: (x + m - 1) // m * m + M_padded = round_up_multiple(M, 128) + K_padded = round_up_multiple(K, 4) + padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) + padded_scale[:B, :M, :K] = scale + batches, rows, cols = padded_scale.shape + assert rows % 128 == 0 + assert cols % 4 == 0 + padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, + cols // 4, 4) + swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) + swizzled_scale = swizzled_scale.contiguous().cuda() + return (swizzled_scale.reshape(M, K) + if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) + + def process_weights_after_loading(self, layer: Module) -> None: + + # global scales: + input_scale_2 = layer.input_scale.max().to(torch.float32) + layer.input_scale = Parameter(input_scale_2, requires_grad=False) + + weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) + layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False) + + layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2, + requires_grad=False) + + # Swizzle the weight blockscale. + # contracting dimension is input dimension + # block_size = 16; + assert (layer.weight_scale.shape[1] % 16 == 0), ( + "Expected weight_scale.dim(1) to be divisible by 16") + assert (layer.weight_scale.dtype == torch.float8_e4m3fn), ( + "Weight Block scale must be represented as FP8-E4M3") + swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale) + + layer.weight_scale_swizzled = Parameter(swizzled_weight_scale, + requires_grad=False) + layer.weight = Parameter(layer.weight.data, requires_grad=False) + + if self.use_marlin: + prepare_fp4_layer_for_marlin(layer) + del layer.alpha + del layer.input_scale + del layer.weight_scale_swizzled + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.use_marlin: + return apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_scale_2=layer.weight_scale_2, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias) + + output_dtype = x.dtype + output_shape = [x.shape[0], layer.weight.shape[0]] + + # quantize BF16 or FP16 to (FP4 and interleaved block scale) + s_quant = 1 / layer.input_scale + x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant) + + # validate dtypes of quantized input, input block scale, + # weight and weight_blockscale + assert (x_fp4.dtype == torch.uint8) + assert (layer.weight.dtype == torch.uint8) + assert (x_blockscale.dtype == torch.float8_e4m3fn) + assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn) + assert (layer.alpha.dtype == torch.float32) + + out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale, + layer.weight_scale_swizzled, layer.alpha, + output_dtype) + if bias is not None: + out = out + bias + return out.view(*output_shape) + + +class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): + """ + MoE Method for FP4 Quantization. + Args: + quant_config: NVFP4 Quant Config + """ + + def __init__(self, quant_config: ModelOptNvFp4Config): + self.quant_config = quant_config + self.cutlass_nvfp4_supported = cutlass_fp4_supported() + self.use_marlin = False + + if not self.cutlass_nvfp4_supported: + if is_fp4_marlin_supported(): + self.use_marlin = True + else: + raise ValueError("Current platform does not support NVFP4" + " quantization. Please use Blackwell and" + " above.") + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + if not self.quant_config.is_checkpoint_nvfp4_serialized: + raise ValueError("NVFP4 quantization was selected, " + " dynamic quantization is not supported.") + + layer.num_experts = num_experts + layer.params_dtype = params_dtype + layer.quant_config = self.quant_config + weight_dtype = torch.uint8 + weight_scale_dtype = torch.float8_e4m3fn + weight_loader = extra_weight_attrs.get("weight_loader") + # GEMM 1 + w13_weight = ModelWeightParameter( + data=torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // 2, + dtype=weight_dtype), + input_dim=1, + output_dim=2, + weight_loader=weight_loader) + layer.register_parameter("w13_weight", w13_weight) + + # GEMM 2 + w2_weight = ModelWeightParameter( + data=torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // 2, + dtype=weight_dtype), + input_dim=1, + output_dim=2, + weight_loader=weight_loader) + layer.register_parameter("w2_weight", w2_weight) + + w13_weight_scale = ModelWeightParameter( + data=torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // self.quant_config.group_size, + dtype=weight_scale_dtype), + input_dim=1, + output_dim=2, + weight_loader=weight_loader) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = ModelWeightParameter( + data=torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // + self.quant_config.group_size, + dtype=weight_scale_dtype), + input_dim=1, + output_dim=2, + weight_loader=weight_loader) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}) + + w13_weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(num_experts, 2, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2) + + w2_weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(num_experts, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + + w13_input_scale = PerTensorScaleParameter(data=torch.empty( + num_experts, 2, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("w13_input_scale", w13_input_scale) + + w2_input_scale = PerTensorScaleParameter(data=torch.empty( + num_experts, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("w2_input_scale", w2_input_scale) + + def swizzle_blockscale(self, scale: torch.tensor): + assert (scale.dtype == torch.float8_e4m3fn) + # Pad and blockwise interleave weight_scale + scale_ndim = scale.ndim + if scale.ndim == 2: + scale = scale.unsqueeze(0) + assert scale.ndim == 3 + B, M, K = scale.shape + round_up_multiple = lambda x, m: (x + m - 1) // m * m + M_padded = round_up_multiple(M, 128) + K_padded = round_up_multiple(K, 4) + padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) + padded_scale[:B, :M, :K] = scale + batches, rows, cols = padded_scale.shape + assert rows % 128 == 0 + assert cols % 4 == 0 + padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, + cols // 4, 4) + swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) + swizzled_scale = swizzled_scale.contiguous().cuda() + return (swizzled_scale.reshape(M, K) + if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + # GEMM 1 + if not torch.allclose(layer.w13_weight_scale_2[:, 0], + layer.w13_weight_scale_2[:, 1]): + logger.warning_once( + "w1_weight_scale_2 must match w3_weight_scale_2. " + "Accuracy may be affected.") + + w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0] + layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, + requires_grad=False) + + w13_input_scale = layer.w13_input_scale.max(dim=1).values.to( + torch.float32) + layer.g1_alphas = Parameter( + (w13_input_scale * w13_weight_scale_2).to(torch.float32), + requires_grad=False) + + assert (layer.w13_weight_scale.shape[2] % 16 == 0), ( + "Expected weight_scale.dim(1) to be divisible by 16") + assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), ( + "Weight Blockscale must be represented as FP8-E4M3") + w13_blockscale_swizzled = self.swizzle_blockscale( + layer.w13_weight_scale) + + layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled, + requires_grad=False) + + # This is for quantization, so we need to invert it. + layer.w13_input_scale_quant = Parameter( + (1 / w13_input_scale).to(torch.float32), requires_grad=False) + + layer.w13_weight = Parameter(layer.w13_weight.data, + requires_grad=False) + + # GEMM 2 + layer.g2_alphas = Parameter( + (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), + requires_grad=False) + + # This is for quantization, so we need to invert it. + layer.w2_input_scale_quant = Parameter( + (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False) + + assert (layer.w2_weight_scale.shape[2] % 16 == 0), ( + "Expected weight_scale.dim(1) to be divisible by 16") + assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), ( + "Weight Blockscale must be represented as FP8-E4M3") + w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale) + + layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled, + requires_grad=False) + layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) + + if self.use_marlin: + prepare_moe_fp4_layer_for_marlin(layer) + del layer.g1_alphas + del layer.g2_alphas + del layer.w13_input_scale_quant + del layer.w2_input_scale_quant + del layer.w13_blockscale_swizzled + del layer.w2_blockscale_swizzled + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ): + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.") + + if self.use_marlin: + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + global_scale1=layer.w13_weight_scale_2, + global_scale2=layer.w2_weight_scale_2, + quant_type_id=scalar_types.float4_e2m1f.id, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map) + + assert activation == "silu", "Only SiLU activation is supported." + assert not apply_router_weight_on_input, ( + "Router weight on input is not " + "supported for ModelOptNvFp4FusedMoE.") + assert expert_map is None, ("Expert Parallelism / expert_map " + "is currently not supported for " + "ModelOptNvFp4FusedMoE.") + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp4) + + # Cutlass moe takes in activations in BF16/Half precision + # and fp4 quantized weights loaded from the checkpoint + return cutlass_moe_fp4(a=x, + w1_fp4=layer.w13_weight, + w1_blockscale=layer.w13_blockscale_swizzled, + w1_alphas=layer.g1_alphas, + w2_fp4=layer.w2_weight, + w2_blockscale=layer.w2_blockscale_swizzled, + w2_alphas=layer.g2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=x.shape[0], + n=layer.w2_weight.shape[2] * 2, + k=x.shape[1], + e=layer.w13_weight.shape[0], + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, + device=x.device).to(x.dtype) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py new file mode 100644 index 0000000..c00e14e --- /dev/null +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -0,0 +1,550 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Callable, Optional + +import torch +import os +from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, + UnquantizedFusedMoEMethod) +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supports_layer) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.model_executor.layers.fused_moe import fused_experts +from vllm.model_executor.layers.quantization.awq import ( + is_layer_skipped_awq) +from lmslim.layers.fused_moe.fuse_moe_int4 import fused_experts_w4a16 + +os.environ['W4A16_MOE_CUDA'] = os.environ.get('W4A16_MOE_CUDA', '0') +os.environ['W4A16_MOE_LMSLIM'] = os.environ.get('W4A16_MOE_LMSLIM', '1') +if os.environ['W4A16_MOE_CUDA'] == '1': + from vllm.model_executor.layers.quantization.utils.fused_moe_cuda import fused_experts_cuda + +class MoeWNA16Config(QuantizationConfig): + """Config class for MOE WNA16 (W8A16/W4A16) quantization.""" + + def __init__(self, linear_quant_method: str, weight_bits: int, + group_size: int, has_zp: bool, lm_head_quantized: bool, + modules_to_not_convert: Optional[list[str]], + full_config: dict[str, Any]) -> None: + super().__init__() + self.weight_bits = weight_bits + self.group_size = group_size + self.has_zp = has_zp + self.bit8_pack_factor = 8 // self.weight_bits + self.lm_head_quantized = lm_head_quantized + self.linear_quant_method = linear_quant_method + self.full_config = full_config + self.use_marlin = False + # Avoid circular import + from vllm.model_executor.layers.quantization.awq import AWQConfig + from vllm.model_executor.layers.quantization.awq_marlin import ( + AWQMarlinConfig) + from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) + if self.linear_quant_method == "gptq": + self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible( + full_config) + elif self.linear_quant_method == "awq": + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + awq_min_capability = AWQConfig.get_min_capability() + if device_capability < awq_min_capability: + raise ValueError( + "The quantization method moe_wna16 + awq is not supported " + "for the current GPU. " + f"Minimum capability: {awq_min_capability}. " + f"Current capability: {device_capability}.") + self.use_marlin = AWQMarlinConfig.is_awq_marlin_compatible( + full_config) + else: + raise ValueError("moe_wna16 only support gptq and awq.") + + if modules_to_not_convert is None: + self.modules_to_not_convert = [] + else: + self.modules_to_not_convert = modules_to_not_convert + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "moe_wna16" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "MoeWNA16Config": + linear_quant_method = cls.get_from_keys(config, ["quant_method"]) + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + if linear_quant_method == "gptq": + has_zp = not cls.get_from_keys(config, ["sym"]) + modules_to_not_convert = [] + elif linear_quant_method == "awq": + has_zp = cls.get_from_keys(config, ["zero_point"]) + modules_to_not_convert = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None) + else: + raise ValueError("moe_wna16 only support gptq and awq.") + + return cls(linear_quant_method, weight_bits, group_size, has_zp, + lm_head_quantized, modules_to_not_convert, config) + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg) + if can_convert and user_quant == "moe_wna16": + return cls.get_name() + return None + + @classmethod + def is_moe_wna16_compatible(cls, quant_config: dict[str, Any]): + # Extract data from quant config. + quant_method = quant_config.get("quant_method", "").lower() + num_bits = quant_config.get("bits") + desc_act = quant_config.get("desc_act") + + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + # Avoid circular import + from vllm.model_executor.layers.quantization.awq import AWQConfig + awq_min_capability = AWQConfig.get_min_capability() + + gptq_compatible = quant_method == "gptq" and \ + not desc_act and num_bits in [4, 8] + awq_compatible = quant_method == "awq" and num_bits == 4 and \ + device_capability >= awq_min_capability + + return gptq_compatible or awq_compatible + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + if is_layer_skipped_quant(prefix, self.modules_to_not_convert): + return UnquantizedLinearMethod() + # Avoid circular import + from vllm.model_executor.layers.quantization.awq import AWQConfig + from vllm.model_executor.layers.quantization.awq_marlin import ( + AWQMarlinConfig) + from vllm.model_executor.layers.quantization.gptq import GPTQConfig + from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) + if self.linear_quant_method == "gptq": + if self.use_marlin: + return GPTQMarlinConfig.from_config( + self.full_config).get_quant_method(layer, prefix) + else: + return GPTQConfig.from_config( + self.full_config).get_quant_method(layer, prefix) + elif self.linear_quant_method == "awq": + if self.use_marlin and check_marlin_supports_layer( + layer, self.group_size): + return AWQMarlinConfig.from_config( + self.full_config).get_quant_method(layer, prefix) + else: + return AWQConfig.from_config( + self.full_config).get_quant_method(layer, prefix) + else: + raise ValueError("moe_wna16 only support gptq and awq.") + elif isinstance(layer, FusedMoE): + if is_layer_skipped_awq( + prefix, getattr(self, "modules_to_not_convert", [])): + return UnquantizedFusedMoEMethod(layer.moe_config) + return MoeWNA16Method(self) + return None + + +def is_layer_skipped_quant(prefix: str, modules_to_not_convert: list[str]): + return any(module_name in prefix for module_name in modules_to_not_convert) + + +class MoeWNA16Method(FusedMoEMethodBase): + """Linear method for MOE WNA16 (W8A16/W4A16) quantization. + + Args: + quant_config: The MOE WNA16 (W8A16/W4A16) quantization config. + """ + + def __init__(self, quant_config: MoeWNA16Config): + self.quant_config = quant_config + self.use_w4a16_moe_sz = os.environ.get('AWQ_MOE_SZ') == '1' + self.use_w4a16_cuda = 0 + self.use_moe_lmslim = 0 + if self.use_w4a16_moe_sz: + self.use_w4a16_cuda = os.environ['W4A16_MOE_CUDA'] == '1' + self.use_moe_lmslim = os.environ['W4A16_MOE_LMSLIM'] == "1" + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + layer.quant_config = self.quant_config + bit8_pack_factor = self.quant_config.bit8_pack_factor + group_size = self.quant_config.group_size + group_size_div_factor = 1 + + # make intermediate_size and hidden_size diviable by group_size + # we reduce the group size to ensure that + # and we would repeat the loaded_weight later + while intermediate_size_per_partition % group_size or \ + hidden_size % group_size: + group_size = group_size // 2 + group_size_div_factor *= 2 + assert group_size >= 32 + layer.group_size = group_size + layer.group_size_div_factor = group_size_div_factor + + strategy = FusedMoeWeightScaleSupported.GROUP.value + extra_weight_attrs.update({ + "quant_method": strategy, + "is_transposed": False + }) + + assert 'weight_loader' in extra_weight_attrs + weight_loader = extra_weight_attrs['weight_loader'] + wrapped_weight_loader = MoeWNA16Method.get_weight_loader( + layer, weight_loader) + extra_weight_attrs['weight_loader'] = wrapped_weight_loader + + # Fused gate_up_proj (column parallel) + w13_qweight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // bit8_pack_factor, + dtype=torch.uint8), + requires_grad=False) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + + # down_proj (row parallel) + w2_qweight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // bit8_pack_factor, + dtype=torch.uint8), + requires_grad=False) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + w13_scales = torch.nn.Parameter(torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // group_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + + w2_scales = torch.nn.Parameter(torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition // group_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + + if self.quant_config.has_zp: + w13_qzeros = torch.nn.Parameter(torch.zeros( + num_experts, + 2 * intermediate_size_per_partition // bit8_pack_factor, + hidden_size // group_size, + dtype=torch.uint8), + requires_grad=False) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + + w2_qzeros = torch.nn.Parameter(torch.zeros( + num_experts, + hidden_size // bit8_pack_factor, + intermediate_size_per_partition // group_size, + dtype=torch.uint8), + requires_grad=False) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + + if self.quant_config.linear_quant_method == "gptq": + # some param are unused, but we need to init them in order to + # load weights + invalid_param_keys = ["w13_g_idx", "w2_g_idx"] + if not self.quant_config.has_zp: + invalid_param_keys += ["w13_qzeros", "w2_qzeros"] + for key in invalid_param_keys: + param = torch.nn.Parameter(torch.empty((0, ), + dtype=torch.int32), + requires_grad=False) + layer.register_parameter(key, param) + set_weight_attrs(param, extra_weight_attrs) + + def restore_qzeros_tensor(self, qzeros, qscales): + + low_bits = qzeros & 0x0F + high_bits = qzeros >> 4 + + zeors_tensor = torch.stack([low_bits, high_bits], dim=2).view(qzeros.shape[0], -1 , qzeros.shape[-1]) + zeors_int16 = zeors_tensor.to(torch.int16) + assert zeors_int16.shape == qscales.shape + + uint16_tensor1 = zeors_int16.view(torch.uint16) + uint16_tensor2 = qscales.view(torch.uint16) + + uint32_tensor1 = uint16_tensor1.to(torch.int32) << 16 + uint32_tensor2 = uint16_tensor2.to(torch.int32) + + result_tensor = uint32_tensor1 + uint32_tensor2 + result_tensor =result_tensor.view(torch.uint32) + result_tensor = result_tensor.transpose(1, 2).contiguous() + return result_tensor + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if self.use_w4a16_moe_sz: + sz_tensor_1 = self.restore_qzeros_tensor(layer.w13_qzeros, layer.w13_scales) + sz_tensor_2 = self.restore_qzeros_tensor(layer.w2_qzeros, layer.w2_scales) + + layer.w13_scales = torch.nn.Parameter(sz_tensor_1,requires_grad=False) + layer.w2_scales = torch.nn.Parameter(sz_tensor_2,requires_grad=False) + layer.w13_qzeros = None + layer.w2_qzeros = None + torch.cuda.empty_cache() + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + use_nn_moe: Optional[bool] = False, + routed_scaling_factor: Optional[float] = None, + use_fused_gate: Optional[bool] = False, + ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `MoeWNA16Method` yet.") + + assert activation == "silu", "Only SiLU activation is supported." + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + routed_scaling_factor=routed_scaling_factor, + use_fused_gate=use_fused_gate) + + weight_bits = self.quant_config.weight_bits + has_zp = self.quant_config.has_zp + + if self.use_moe_lmslim: + return fused_experts_w4a16( + x, + layer.w13_qweight, + layer.w2_qweight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_int4_w4a16=True, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + block_shape=[0, layer.group_size]) + + if self.use_w4a16_cuda: + m = topk_ids.shape[0] + if m <= 512: + return fused_experts_cuda(x, + layer.w13_qweight, + layer.w2_qweight, + topk_weights, + topk_ids, + inplace=True, + use_fp8_w8a8=False, + use_int4_w4a16=weight_bits == 4, + use_int8_w8a16=False, + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + w1_zp=None, + w2_zp=None, + a1_scale=None, + a2_scale=None, + block_shape=[0, layer.group_size], + expert_map=expert_map) + + return fused_experts( + x, + layer.w13_qweight, + layer.w2_qweight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + use_int4_w4a16=weight_bits == 4, + use_int8_w8a16=weight_bits == 8, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + w1_zp=layer.w13_qzeros if has_zp else None, + w2_zp=layer.w2_qzeros if has_zp else None, + block_shape=[0, layer.group_size], + use_nn_moe=False) + + @staticmethod + def get_weight_loader(layer, weight_loader): + + def convert_awq_tensor(tensor, tensor_type): + # convert awq qweight/qzeros to a standard format (assume int4) + # qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8) + # qzeros: (k // group_size, n // pack_factor_bit32) -> + # (n // pack_factor_bit8, k // group_size) + # pack_factor_bit32 = 32 // weight_bits + # pack_factor_bit8 = 8 // weight_bits + + # 0. suppose origin shape (a, b), dtype int32 + # 1. convert to uint8, shape (a, b) -> (a, 4 * b) + size0 = tensor.size(0) + tensor = tensor.view(torch.uint8) + + # 2. unpack to uint4 (only when weight_bits == 4) + # shape (a, 4 * b) -> (a, 4 * b, 2) + shifter = torch.tensor([0, 4], + dtype=torch.uint8, + device=tensor.device) + tensor = (tensor[:, :, None] >> shifter) & 0xF + + # 3. change order, see + # https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py + # shape -> (a, 4 * b * pack_factor_bit8) + reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7] + tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order] + tensor = tensor.view(size0, -1) + + # 4. transpose, shape -> (4 * b * pack_factor_bit8, a) + tensor = tensor.T.contiguous() + + # 5. repack (only when weight_bits == 4) + # qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8) + # qzeros shape -> (4 * b, a) + + if tensor_type == "qweight": + tensor = tensor[:, 1::2] * 16 + tensor[:, ::2] + elif tensor_type == "qzeros": + tensor = tensor[1::2, :] * 16 + tensor[::2, :] + return tensor + + def convert_gptq_int4_qzeros(tensor): + tensor = tensor.view(torch.uint8) + shifter = torch.tensor([0, 4], + dtype=torch.uint8, + device=tensor.device) + tensor = (tensor[:, :, None] >> shifter) & 0xF + tensor = tensor + 1 + tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16 + return tensor + + def moe_wna16_weight_loader(param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, shard_id: str, + expert_id: int, + return_success: bool = False): + if "g_idx" in weight_name: + return + if not layer.quant_config.has_zp and "qzeros" in weight_name: + return + + device = get_tp_group().device + tp_rank = get_tensor_model_parallel_rank() + loaded_weight = loaded_weight.to(device) + shard_size = layer.intermediate_size_per_partition + + # convert gptq and awq weight to a standard format + if layer.quant_config.linear_quant_method == "awq": + assert layer.quant_config.weight_bits == 4 + if "weight" in weight_name: + loaded_weight = convert_awq_tensor(loaded_weight, + "qweight") + elif "zeros" in weight_name: + loaded_weight = convert_awq_tensor(loaded_weight, "qzeros") + else: + loaded_weight = loaded_weight.T + elif layer.quant_config.linear_quant_method == "gptq": + assert layer.quant_config.weight_bits in [4, 8] + if "weight" in weight_name: + loaded_weight = loaded_weight.T.contiguous().view( + torch.uint8) + elif "zeros" in weight_name: + # add 1 to gptq qzeros to align with awq + loaded_weight = loaded_weight.view(torch.uint8) + if layer.quant_config.weight_bits == 4: + loaded_weight = convert_gptq_int4_qzeros( + loaded_weight).T + else: + loaded_weight = loaded_weight.T + 1 + else: + loaded_weight = loaded_weight.T + + # repeat the qzeros/scales to fit new group size + if layer.group_size_div_factor > 1 and \ + "qzeros" in weight_name or "scales" in weight_name: + loaded_weight = loaded_weight.repeat_interleave( + layer.group_size_div_factor, 1) + + if "w13_qzeros" in weight_name: + tensor = loaded_weight.view(layer.tp_size, -1, + loaded_weight.size(1))[tp_rank] + if shard_id == "w1": + param.data[expert_id, :shard_size // 2] = tensor + else: + param.data[expert_id, shard_size // 2:] = tensor + elif "w2_qzeros" in weight_name: + param.data[expert_id] = loaded_weight.view( + loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank] + else: + weight_loader(param, loaded_weight, weight_name, shard_id, + expert_id) + return_success = True + return return_success + return moe_wna16_weight_loader \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/neuron_quant.py b/vllm/model_executor/layers/quantization/neuron_quant.py new file mode 100644 index 0000000..8040236 --- /dev/null +++ b/vllm/model_executor/layers/quantization/neuron_quant.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from importlib.util import find_spec +from typing import Any, Optional + +from torch.nn import Module + +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + +SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn'] + + +class AlwaysSupportedDtypes(list): + + def __contains__(self, item): + return True + + +class NeuronQuantConfig(QuantizationConfig): + """Int8 Quantization Config class for Neuron Backend.""" + + def __init__( + self, + dequant_dtype: str = "f16", + quantize_method: str = "vector_dynamic", + ) -> None: + super().__init__() + self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8") + if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST: + raise ValueError( + f"Neuron quantization datatype {self.quant_dtype} is not valid," + f" the quantization datatype should match one of the below " + f"types {SUPPORTED_QUANT_DTYPE_LIST}") + self.dequant_dtype = dequant_dtype + self.quantize_method = quantize_method + + def get_name(self) -> QuantizationMethods: + return "neuron_quant" + + def get_supported_act_dtypes(self) -> list[str]: + # Neuron implements custom handling logic for quantization support + return AlwaysSupportedDtypes() + + @classmethod + def get_min_capability(cls) -> int: + raise NotImplementedError( + "This function should not be called with Neuron Backend") + + @staticmethod + def get_config_filenames() -> list[str]: + return [] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "NeuronQuantConfig": + quantize_method = cls.get_from_keys(config, ["quantize_method"]) + dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"]) + return cls(dequant_dtype=dequant_dtype, + quantize_method=quantize_method) + + def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]: + if find_spec("transformers_neuronx") is not None: + return self.get_quantization_config() + else: + raise NotImplementedError( + "Neuron Quantization is only supported through" + " transformers_neuronx.") + + def get_quantization_config(self): + from transformers_neuronx.config import QuantizationConfig + return QuantizationConfig(quant_dtype=self.quant_dtype, + dequant_dtype=self.dequant_dtype, + quantize_method=self.quantize_method) diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py new file mode 100644 index 0000000..32ba105 --- /dev/null +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizeMethodBase) +from vllm.model_executor.layers.quantization.fp8 import (Fp8Config, + Fp8KVCacheMethod, + Fp8LinearMethod) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp) +from vllm.platforms import current_platform + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = init_logger(__name__) + + +class PTPCFp8Config(Fp8Config): + """Config class for Per-Token-Per-Channel Dynamic Quantization Fp8.""" + + def __init__( + self, + activation_scheme: str = "dynamic", + ignored_layers: Optional[list[str]] = None, + ) -> None: + if not current_platform.is_rocm(): + raise ValueError( + "ptpc_fp8 quantization is supported only on ROCm.") + + if not current_platform.has_device_capability(94): + raise ValueError( + "ptpc_fp8 quantization is supported only on AMD Instinct MI300 GPUs and newer." # noqa: E501 + ) + if activation_scheme == "static": + raise ValueError( + "ptpc_fp8 as of now only support dynamic quantization.") + + super().__init__(is_checkpoint_fp8_serialized=False, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers) + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "ptpc_fp8" + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "PTPCFp8Config": + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + return cls(activation_scheme=activation_scheme, + ignored_layers=ignored_layers) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, self.ignored_layers): + return UnquantizedLinearMethod() + return PTPCFp8LinearMethod(self) + elif isinstance(layer, Attention): + return Fp8KVCacheMethod(self) + return None + + +class PTPCFp8LinearMethod(Fp8LinearMethod): + """Linear method for Per-Token and Per-Channel FP8 Quantization. + Only supports loading quantized BF16 model checkpoints with dynamic + activation scaling. To load FP16 model checkpoints, user must specify + to convert the FP16 model weight loading into BF16. + The weight scaling factor will be initialized after + the model weights are loaded. + + Limitations: + 1. Only support float8_e4m3fnuz data type due to the limitation of + torch._scaled_mm (https://github.com/ROCm/pytorch/blob/8c0504d7f3fb0ee4c278c096a5c3caedb01129fa/aten/src/ATen/native/cuda/Blas.cpp#L1041) + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: PTPCFp8Config): + super().__init__(quant_config=quant_config) + # Force weight quantization + self.quant_config.is_checkpoint_fp8_serialized = False + self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=False, + use_per_token_if_dynamic=True) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.weight = torch.nn.Parameter(layer.weight.data, + requires_grad=False) + + assert layer.weight.data.dtype == torch.bfloat16, \ + f"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. {str(layer.weight.data.dtype)} is specified." # noqa: E501 + # Quantize the weights. + qweight, weight_scale = ops.scaled_fp8_quant( + layer.weight, scale=None, use_per_token_if_dynamic=True) + + # Update the layer with the new values. + layer.weight = Parameter( + qweight.t(), requires_grad=False) # Pretranspose the weight + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.input_scale = None + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + return self.fp8_linear.apply(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=None, + input_scale_ub=None, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/qqq.py b/vllm/model_executor/layers/quantization/qqq.py new file mode 100644 index 0000000..25978cb --- /dev/null +++ b/vllm/model_executor/layers/quantization/qqq.py @@ -0,0 +1,275 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) + +logger = init_logger(__name__) + +MARLIN_QQQ_TILE = 16 +MARLIN_QQQ_MIN_THREAD_N = 64 +MARLIN_QQQ_MIN_THREAD_K = 128 +MARLIN_QQQ_MAX_PARALLEL = 16 + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] +MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] +MARLIN_QQQ_SUPPORTED_SYM = [True] + + +class QQQConfig(QuantizationConfig): + """Config class for QQQ + + Reference: https://arxiv.org/pdf/2406.09904 + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + is_sym: bool = True, + ) -> None: + super().__init__() + self.weight_bits = weight_bits + self.group_size = group_size + self.is_sym = is_sym + + # Verify + if self.weight_bits not in MARLIN_QQQ_SUPPORTED_NUM_BITS: + raise ValueError( + f"QQQ does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {MARLIN_QQQ_SUPPORTED_NUM_BITS} " + "are supported.") + if self.group_size not in MARLIN_QQQ_SUPPORTED_GROUP_SIZES: + raise ValueError( + f"QQQ does not support group_size = {self.group_size}. " + f"Only group_sizes = {MARLIN_QQQ_SUPPORTED_GROUP_SIZES} " + "are supported.") + if self.is_sym not in MARLIN_QQQ_SUPPORTED_SYM: + raise ValueError( + f"QQQ does not support is_sym = {self.is_sym}. " + f"Only sym = {MARLIN_QQQ_SUPPORTED_SYM} are supported.") + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = 32 // self.weight_bits + + # Tile size used by QQQ kernels. + self.tile_size = MARLIN_QQQ_TILE + + # Min out_features dim + self.min_n_threads = MARLIN_QQQ_MIN_THREAD_N + + # Min in_features dim + self.min_k_threads = MARLIN_QQQ_MIN_THREAD_K + + # Max parallel problems to solve at once (improves large + # batch performance) + self.max_parallel = MARLIN_QQQ_MAX_PARALLEL + + # Permutation length used by the QQQ kernels. + self.perm_len = 1024 + + def __repr__(self) -> str: + return "QQQConfig(weight_bits={}, group_size={})".format( + self.weight_bits, self.group_size) + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "qqq" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + """List of filenames to search for in the model directory.""" + return [ + "quant_config.json", + "quantize_config.json", + ] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "QQQConfig": + weight_bits = cls.get_from_keys(config, ["wbits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits, group_size) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QQQLinearMethod"]: + if isinstance(layer, LinearBase): + return QQQLinearMethod(self) + return None + + +class QQQLinearMethod(LinearMethodBase): + """Linear method for QQQ. + + Args: + quant_config: The QQQ quantization config. + """ + + def __init__(self, quant_config: QQQConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight_loader = extra_weight_attrs["weight_loader"] + if params_dtype != torch.float16: + raise ValueError( + f"The params dtype must be float16, but got {params_dtype}") + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.min_n_threads != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"min_n_threads = {self.quant_config.min_n_threads}.") + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"pack_factor = {self.quant_config.pack_factor}.") + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_k_threads != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"min_k_threads = {self.quant_config.min_k_threads}.") + if (self.quant_config.group_size != -1 and + input_size_per_partition % self.quant_config.group_size != 0): + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"group_size = {self.quant_config.group_size}.") + + # Check that we have at least 4 tiles horizontally in the shard + num_tiles_per_perm = self.quant_config.perm_len // ( + self.quant_config.tile_size**2) + if output_size_per_partition % num_tiles_per_perm != 0: + raise ValueError( + "Each permutation group must reside on the same gpu") + + # Quantized 4Bit weights packed into Int32. + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.tile_size, + output_size_per_partition * self.quant_config.tile_size // + self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + marlin_tile_size=self.quant_config.tile_size, + weight_loader=weight_loader) + + s_channel = ChannelQuantScaleParameter(data=torch.empty( + 1, + output_size_per_partition, + device="cuda", + dtype=torch.float, + ), + weight_loader=weight_loader, + output_dim=1) + + if self.quant_config.group_size == -1: + s_group_data = torch.tensor( + [], + device="cuda", + dtype=torch.half, + ) + else: + s_group_data = torch.empty( + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition, + device="cuda", + dtype=torch.half, + ) + + s_group_attr = {"data": s_group_data, "weight_loader": weight_loader} + + if self.quant_config.group_size == -1: + s_group = BasevLLMParameter(**s_group_attr) + else: + s_group = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **s_group_attr) + + # Allocate workspace (Used for internal locking mechanism) + max_workspace_size = ( + output_size_per_partition // + self.quant_config.min_n_threads) * self.quant_config.max_parallel + + workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, + device="cuda", + dtype=torch.int), + weight_loader=weight_loader) + + layer.register_parameter("B", qweight) + layer.register_parameter("s_channel", s_channel) + layer.register_parameter("s_group", s_group) + layer.register_parameter("workspace", workspace) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # required by torch.compile + layer.B = Parameter(layer.B.data, requires_grad=False) + layer.s_channel = Parameter(layer.s_channel.data, requires_grad=False) + layer.s_group = Parameter(layer.s_group.data, requires_grad=False) + layer.workspace = Parameter(layer.workspace.data, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.B + s_ch = layer.s_channel + s_group = layer.s_group + workspace = layer.workspace + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = s_ch.shape[1] + + x_int8, s_tok, _ = ops.scaled_int8_quant(x_2d) + + output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group, + workspace, size_m, size_n, size_k) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/vllm/model_executor/layers/quantization/quark/__init__.py b/vllm/model_executor/layers/quantization/quark/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py new file mode 100644 index 0000000..05dff4b --- /dev/null +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -0,0 +1,437 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import fnmatch +from typing import Any, Optional, cast + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501 + QuarkMoEMethod) +from vllm.model_executor.layers.quantization.quark.schemes import ( + QuarkScheme, QuarkW4A4MXFP4, QuarkW8A8Fp8, QuarkW8A8Int8) +from vllm.model_executor.layers.quantization.quark.utils import ( + deep_compare, should_ignore_layer) +from vllm.platforms import current_platform + +__all__ = ["QuarkLinearMethod"] + +logger = init_logger(__name__) + + +class QuarkConfig(QuantizationConfig): + + def __init__(self, + quant_config: dict[str, Any], + kv_cache_group: Optional[list[str]] = None, + kv_cache_config: Optional[dict[str, Any]] = None, + pack_method: str = "reorder"): + super().__init__() + if kv_cache_group is None: + kv_cache_group = [] + self.quant_config = quant_config + self.kv_cache_group = kv_cache_group + self.kv_cache_config = kv_cache_config + self.pack_method = pack_method + + def get_linear_method(self) -> "QuarkLinearMethod": + return QuarkLinearMethod(self) + + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + def get_name(self) -> QuantizationMethods: + return "quark" + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + # Check if the layer is skipped for quantization. + exclude_layers = cast(list[str], self.quant_config.get("exclude")) + if should_ignore_layer(prefix, + ignore=exclude_layers, + fused_mapping=self.packed_modules_mapping): + return UnquantizedLinearMethod() + if isinstance(layer, LinearBase): + scheme = self.get_scheme(layer=layer, layer_name=prefix) + layer.scheme = scheme + return QuarkLinearMethod(self) + if isinstance(layer, Attention): + return QuarkKVCacheMethod(self) + + if isinstance(layer, FusedMoE): + return QuarkMoEMethod.get_moe_method(self, + module=layer, + layer_name=prefix) + return None + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "QuarkConfig": + export_config = config.get("export") + if export_config is None: + raise ValueError("The export key should be included in " + "the configurations of Quark quantized model") + kv_cache_group = cast(list[str], export_config.get("kv_cache_group")) + pack_method = cast(str, export_config.get("pack_method")) + + # In the export model of quark, the quantization configuration + # of kv_cache is stored in layer_quant_config. First, it is + # judged whether kv_cache_group exists, and then it is judged + # whether layer_quant_config has a quantization configuration + # that matches kv_cache. + if len(kv_cache_group) == 0: + kv_cache_config = None + else: + kv_cache_set = set(kv_cache_group) + layer_quant_config = cast(dict[str, Any], + config.get("layer_quant_config")) + layer_quant_names = list(layer_quant_config.keys()) + layer_quant_set = set(layer_quant_names) + + if not kv_cache_set.issubset(layer_quant_set): + raise ValueError("The Quark quantized model has the " + "kv_cache_group parameter setting, " + "but no kv_cache quantization settings " + "were found in the quantization " + "configuration.") + + q_configs = [ + cast(dict[str, Any], layer_quant_config.get(name)) + for name in kv_cache_group + ] + if not all( + deep_compare(q_config, q_configs[0]) + for q_config in q_configs): + raise ValueError( + "The quantization method used for kv_cache should " + "be the same, but the quantization method for the " + "kv_cache layer in the config is different.") + kv_cache_config = q_configs[0].get("output_tensors") + if kv_cache_config is None: + raise ValueError( + "The kv_cache quantization configuration is empty.") + + # Since we have already set kv_cache quantization configurations, + # we will remove the quantization configuration for the + # output_tensors corresponding to the kv_cache layer. + for q_config in q_configs: + q_config["output_tensors"] = None + + # In case q_proj output is also quantized, remove the configuration + # to keep qkv consistency. + q_proj_q_config = cast(dict[str, Any], + layer_quant_config.get("*q_proj")) + if q_proj_q_config is not None: + q_proj_q_config["output_tensors"] = None + + return cls(quant_config=config, + kv_cache_group=kv_cache_group, + kv_cache_config=kv_cache_config, + pack_method=pack_method) + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + def _check_scheme_supported(self, + min_capability: int, + error: bool = True) -> bool: + capability_tuple = current_platform.get_device_capability() + + if capability_tuple is not None: + capability = capability_tuple.to_int() + supported = capability >= min_capability + if error and not supported: + raise RuntimeError( + "Quantization scheme is not supported for ", + f"the current GPU. Min capability: {min_capability}. ", + f"Current capability: {capability}.") + return supported + else: + return False + + def _is_fp8_w8a8(self, weight_quant: Optional[dict[str, Any]], + input_quant: Optional[dict[str, Any]]) -> bool: + # Confirm weights and input quantized. + if weight_quant is None or input_quant is None: + return False + + # Confirm weight scheme is supported + is_fp8_dtype = (weight_quant.get("dtype") == "fp8_e4m3" + and input_quant.get("dtype") == "fp8_e4m3") + is_static_weight = not weight_quant.get("is_dynamic") + is_per_tensor_or_channel_weight = (weight_quant.get("qscheme") + in ["per_tensor", "per_channel"]) + + if not (is_fp8_dtype and is_static_weight + and is_per_tensor_or_channel_weight): + return False + + # Dynamic quantization is always supported if weights supported. + if input_quant.get("is_dynamic"): + return True + + # Confirm activation scheme is supported. + is_per_tensor_activation = (input_quant.get("qscheme") == "per_tensor") + return is_per_tensor_activation + + def _is_static_tensor_w8a8(self, weight_quant: Optional[dict[str, Any]], + input_quant: Optional[dict[str, Any]]) -> bool: + # Confirm weights and input quantized. + if weight_quant is None or input_quant is None: + return False + + is_int8_dtype = (weight_quant.get("dtype") == "int8" + and input_quant.get("dtype") == "int8") + + is_tensor = (weight_quant.get("qscheme") + in ["per_tensor", "per_channel"] + and input_quant.get("qscheme") == "per_tensor") + + is_static = (not weight_quant.get("is_dynamic") + and not input_quant.get("is_dynamic")) + + is_weight_symmetric = (weight_quant.get("symmetric") is True) + + # Both symmetric and asymmetric input quantization supported. + # Only symmetric weight quantization supported. + return is_int8_dtype and is_tensor and is_weight_symmetric and is_static + + def _is_mx_fp4(self, weight_quant: Optional[dict[str, Any]], + input_quant: Optional[dict[str, Any]]) -> bool: + # Confirm weights and input quantized. + if weight_quant is None or input_quant is None: + logger.debug("Quark model is not in MX-FP4 format: " + "weight_quant or input_quant not set") + return False + + # Input and weight dtype needs to be fp4. + if weight_quant.get("dtype") != "fp4" or input_quant.get( + "dtype") != "fp4": + logger.debug("Quark model is not in MX-FP4 format: dtype not fp4") + return False + + # Input and weight qscheme needs to be per group. + if weight_quant.get("qscheme") != "per_group" or input_quant.get( + "qscheme") != "per_group": + logger.debug("Quark model is not in MX-FP4 format: not per_group") + return False + + # Input and weight group size needs to be 32. + if weight_quant.get("group_size") != 32 or input_quant.get( + "group_size") != 32: + logger.debug( + "Quark model is not in MX-FP4 format: not group_size=32") + return False + + # Weights need to use static quantization. + if weight_quant.get("is_dynamic") is True: + logger.debug( + "Quark model is not in MX-FP4 format: not weight static") + return False + + # Activations need to use dynamic quantization. + if input_quant.get("is_dynamic") is False: + logger.debug( + "Quark model is not in MX-FP4 format: not activation dynamic") + return False + + # Activations and weight scales need to be in e8m0 format. + if weight_quant.get("scale_format") != "e8m0" or input_quant.get( + "scale_format") != "e8m0": + logger.debug( + "Quark model is not in MX-FP4 format: not scale_format e8m0") + return False + + return True + + def _find_matched_config(self, layer_name: str, + module: torch.nn.Module) -> dict[str, Any]: + + proj_name = layer_name.split(".")[-1] + if proj_name in self.packed_modules_mapping: + shard_proj_names = self.packed_modules_mapping[proj_name] + + # Convert fused_name --> [shard_names] + shard_names = [ + layer_name.replace(proj_name, shard_proj_name) + for shard_proj_name in shard_proj_names + ] + shard_configs = [ + self._find_matched_config(shard_name, module) + for shard_name in shard_names + ] + if not all( + deep_compare(q_config, shard_configs[0]) + for q_config in shard_configs): + raise ValueError( + f"Found a different quantization configuration for " + f"{shard_proj_names} in {layer_name}. vLLM " + "requires all to use the same scheme.") + return shard_configs[0] + else: + layer_quant_config = cast( + dict[str, Any], self.quant_config.get("layer_quant_config")) + for name_pattern in layer_quant_config: + if fnmatch.fnmatch(layer_name, name_pattern): + return layer_quant_config[name_pattern] + + layer_type = cast(str, type(module)) + layer_type_quant_config = cast( + dict[str, Any], + self.quant_config.get("layer_type_quant_config")) + if layer_type in layer_type_quant_config: + return layer_type_quant_config[layer_type] + + global_quant_config = cast( + dict[str, Any], self.quant_config.get("global_quant_config")) + return global_quant_config + + def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme": + if config.get("output_tensors") or config.get("bias"): + raise NotImplementedError( + "Currently, Quark models with output_tensors " + "and bias quantized are not supported") + weight_config = cast(dict[str, Any], config.get("weight")) + input_config = cast(dict[str, Any], config.get("input_tensors")) + + if self._is_fp8_w8a8(weight_config, input_config): + is_fp8_w8a8_supported = self._check_scheme_supported( + QuarkW8A8Fp8.get_min_capability(), error=False) + if is_fp8_w8a8_supported: + return QuarkW8A8Fp8(weight_config, input_config) + elif self._is_static_tensor_w8a8(weight_config, input_config): + weight_qscheme = cast(str, weight_config.get("qscheme")) + return QuarkW8A8Int8(qscheme=weight_qscheme, + is_static_input_scheme=True, + input_symmetric=input_config.get("symmetric")) + elif self._is_mx_fp4(weight_config, input_config): + return QuarkW4A4MXFP4(weight_config, input_config) + + raise NotImplementedError("No quark compatible scheme was found. " + f"Weight config: {weight_config}, " + f"Input config: {input_config}") + + def get_scheme(self, layer: torch.nn.Module, + layer_name: str) -> "QuarkScheme": + + layer_quant_config = self._find_matched_config(layer_name, layer) + + # Find the quant_scheme + scheme = self._get_scheme_from_config(layer_quant_config) + # Raise error if device does not support the scheme + # (e.g. fp8 needs ada lovelace) + self._check_scheme_supported(scheme.get_min_capability()) + + return scheme + + def get_cache_scale(self, name: str) -> Optional[str]: + """ + Check whether the param name matches the format for k/v cache scales + in quark. If this is the case, return its equivalent param name + expected by vLLM + + :param name: param name + :return: matching param name for KV cache scale in vLLM + """ + if name.endswith(".output_scale") and ".k_proj" in name: + return name.replace(".k_proj.output_scale", ".attn.k_scale") + if name.endswith(".output_scale") and ".v_proj" in name: + return name.replace(".v_proj.output_scale", ".attn.v_scale") + if name.endswith(".output_scale") and ".q_proj" in name: + return name.replace(".q_proj.output_scale", ".attn.q_scale") + if name.endswith("self_attn.prob_output_scale"): + return name.replace(".prob_output_scale", ".attn.prob_scale") + + # If no matches, return None + return None + + +class QuarkLinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: QuarkConfig): + self.quantization_config = quantization_config + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.scheme.process_weights_after_loading(layer) + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """ + Use the CompressedTensorsScheme associated with each layer to create + the necessary parameters for the layer. See LinearMethodBase for param + details + """ + weight_loader = extra_weight_attrs.get("weight_loader") + layer.scheme.create_weights( + layer=layer, + input_size=input_size, + input_size_per_partition=input_size_per_partition, + output_partition_sizes=output_partition_sizes, + output_size=output_size, + params_dtype=params_dtype, + weight_loader=weight_loader) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None): + """ + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the + layer input. See LinearMethodBase for param details + + """ + scheme = layer.scheme + if scheme is None: + raise ValueError("A scheme must be defined for each layer") + return scheme.apply_weights(layer, x, bias=bias) + + +class QuarkKVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from quark checkpoints. + """ + + def __init__(self, quant_config: QuarkConfig): + self.validate_kv_cache_config(quant_config.kv_cache_config) + super().__init__(quant_config) + + @staticmethod + def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]): + """ + Validator for the kv cache configuration. Useful for controlling the + kv cache quantization schemes, that are being supported in vLLM + :param kv_cache_config: the quark kv cache scheme + """ + if kv_cache_config is None: + return + + dtype = kv_cache_config.get("dtype") + if dtype != "fp8_e4m3": + raise NotImplementedError( + "Currently supported kv cache quantization is " + f"dtype=fp8_e4m3, however received {dtype}") + + qscheme = kv_cache_config.get("qscheme") + if qscheme != "per_tensor": + raise NotImplementedError( + "Only support per-tensor scaling factor " + "for quark KV cache. " + f"Expected qscheme: per_tensor, found qscheme: {qscheme}") diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py new file mode 100644 index 0000000..a040c43 --- /dev/null +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -0,0 +1,245 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Callable, Optional + +import torch + +import vllm.model_executor.layers.fused_moe # noqa +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, + FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod"] + + +class QuarkMoEMethod(FusedMoEMethodBase): + + @staticmethod + def get_moe_method( + quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821 + module: torch.nn.Module, + layer_name: str) -> "QuarkMoEMethod": + layer_quant_config = quant_config._find_matched_config( + layer_name, module) + + if (layer_quant_config.get("output_tensors") + or layer_quant_config.get("bias")): + raise NotImplementedError("Currently, Quark models with " + "output_tensors and bias " + "quantized are not supported") + weight_config = layer_quant_config.get("weight") + input_config = layer_quant_config.get("input_tensors") + + if quant_config._is_fp8_w8a8(weight_config, input_config): + return QuarkW8A8Fp8MoEMethod(weight_config, input_config) + else: + raise RuntimeError("Unsupported FusedMoe scheme") + + +class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): + + def __init__(self, weight_config: dict[str, Any], input_config: dict[str, + Any]): + self.weight_quant = weight_config + self.input_quant = input_config + + weight_qscheme = self.weight_quant.get("qscheme") + input_qscheme = self.input_quant.get("qscheme") + if not (weight_qscheme == "per_tensor" + and input_qscheme == "per_tensor"): + raise ValueError( + "For FP8 Fused MoE layers, only per-tensor scales " + "for weights and activations are supported. Found " + f"{weight_qscheme}, {input_qscheme}") # noqa E501 + + self.static_input_scales = not self.input_quant.get("is_dynamic") + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.static_input_scales: + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.static_input_scales: + if (layer.w13_input_scale is None or layer.w2_input_scale is None): + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + if (not all_close_1d(layer.w13_input_scale) + or not all_close_1d(layer.w2_input_scale)): + logger.warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. ") + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False) + + if current_platform.is_fp8_fnuz(): + # Normalize the weights and scales + w13_weight, w13_weight_scale, w13_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, + layer.w13_input_scale) + w2_weight, w2_weight_scale, w2_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, + layer.w2_input_scale) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, + requires_grad=False) + if w13_input_scale is not None: + layer.w13_input_scale = torch.nn.Parameter(w13_input_scale, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, + requires_grad=False) + if w2_input_scale is not None: + layer.w2_input_scale = torch.nn.Parameter(w2_input_scale, + requires_grad=False) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.local_num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.") + + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8_w8a8=True, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/__init__.py b/vllm/model_executor/layers/quantization/quark/schemes/__init__.py new file mode 100644 index 0000000..ec09d9b --- /dev/null +++ b/vllm/model_executor/layers/quantization/quark/schemes/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .quark_scheme import QuarkScheme +from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4 +from .quark_w8a8_fp8 import QuarkW8A8Fp8 +from .quark_w8a8_int8 import QuarkW8A8Int8 + +__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkW4A4MXFP4"] diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py new file mode 100644 index 0000000..c167e94 --- /dev/null +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from typing import Optional + +import torch + +__all__ = ["QuarkScheme"] + + +class QuarkScheme(ABC): + """ + Abstract class used to describe the weight creation and forward pass + of different quantization schemes supported by Quark. + """ + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + """ + Get minimum device capability. + """ + raise NotImplementedError + + @abstractmethod + def create_weights(self, *args, **kwargs): + """ + Weight creation for the particular scheme. Inputs to this function + + """ + raise NotImplementedError + + @abstractmethod + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]): + """ + Run the forward pass for the particular scheme. This is where + scheme-specific dequant/quant steps/kernels should be applied. + + :param layer: torch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. + :param x: input to the layer + :param bias: bias parameter + + """ + raise NotImplementedError + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module): + """ + Called after weight loading is complete for any cleanup that + needs to occur. + """ + raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py new file mode 100644 index 0000000..3c56251 --- /dev/null +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Callable, Optional + +import torch +import torch.nn.functional as F + +import vllm.envs as envs +from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( + OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4) +from vllm.model_executor.parameter import (GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.platforms import current_platform + +__all__ = ["QuarkW4A4MXFP4"] + + +class QuarkW4A4MXFP4(QuarkScheme): + + def __init__(self, weight_quant_spec: dict[str, Any], + input_quant_spec: dict[str, Any]): + self.out_dtype = torch.get_default_dtype() + self.qscheme = "per_group" + self.weight_quant_spec = weight_quant_spec + self.input_quant_spec = input_quant_spec + self.emulate = not current_platform.supports_mx() + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.weight = torch.nn.Parameter(layer.weight.data, + requires_grad=False) + layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, + requires_grad=False) + + if self.emulate: + try: + from quark.torch.export.nn.modules import realquantizer + from quark.torch.quantization.config.config import ( + QuantizationSpec) + except ImportError as err: + raise ImportError( + "The package `amd-quark` is required to use AMD Quark " + "MX-FP4 models. Please install it with `pip install " + "amd-quark`.") from err + + weight_quant_spec = QuantizationSpec.from_dict( + self.weight_quant_spec) + + weight_quantizer = realquantizer.get_real_quantizer( + qspec=weight_quant_spec, + quantizer=None, + real_quantized=True, + reorder=False, + float_dtype=self.out_dtype, + scale_shape=layer.weight_scale.shape, + zero_point_shape=None, + ) + weight_quantizer.scale.data = layer.weight_scale.data + + if not envs.VLLM_QUARK_EMU_MEM_OPT: + layer.weight = torch.nn.Parameter( + weight_quantizer(layer.weight.data).to(self.out_dtype), + requires_grad=False, + ) + else: + self.weight_quantizer = weight_quantizer + layer.weight_scale = None + + # This call is necessary to release the scales memory. + torch.cuda.empty_cache() + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + + # WEIGHT + weight = PackedvLLMParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + packed_dim=1, + packed_factor=2, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // OCP_MX_BLOCK_SIZE, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + if self.emulate: + if envs.VLLM_QUARK_EMU_MEM_OPT: + dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype) + else: + dq_w = layer.weight + qdq_x, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE) + return F.linear(qdq_x, dq_w, bias) + else: + raise NotImplementedError() diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py new file mode 100644 index 0000000..c7bc981 --- /dev/null +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Callable, Optional, cast + +import torch +from torch.nn import Parameter + +from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) +from vllm.platforms import current_platform + +__all__ = ["QuarkW8A8Fp8"] + + +class QuarkW8A8Fp8(QuarkScheme): + + def __init__(self, weight_config: dict[str, Any], + input_config: Optional[dict[str, Any]]): + self.weight_qscheme = cast(str, weight_config.get("qscheme")) + self.is_static_input_scheme: bool = False + self.input_qscheme: Optional[str] = None + if input_config is not None: + self.is_static_input_scheme = not cast( + bool, input_config.get("is_dynamic")) + self.input_qscheme = cast(str, input_config.get("qscheme")) + self.use_per_token_if_dynamic = (not self.is_static_input_scheme \ + and self.input_qscheme == "per_channel") + self.fp8_linear = Fp8LinearOp( + use_per_token_if_dynamic=self.use_per_token_if_dynamic) + self.out_dtype = torch.get_default_dtype() + + @classmethod + def get_min_capability(cls) -> int: + # lovelace and up + return 89 + + def process_weights_after_loading(self, layer) -> None: + # If per tensor, when we have a fused module (e.g. QKV) with per + # tensor scales (thus N scales being passed to the kernel), + # requantize so we can always run per tensor + if self.weight_qscheme == "per_tensor": + if current_platform.is_rocm(): + input_scale = getattr(layer, 'input_scale', None) + weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=input_scale) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, + requires_grad=False) + else: + max_w_scale = layer.weight_scale + weight = layer.weight + + max_w_scale, weight = requantize_with_max_scale( + weight=weight, + weight_scale=max_w_scale, + logical_widths=layer.logical_widths, + ) + + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + + # If channelwise, scales are already lined up, so just transpose. + elif self.weight_qscheme == "per_channel": + weight = layer.weight + + if current_platform.is_fp8_fnuz(): + input_scale = getattr(layer, 'input_scale', None) + weight, weight_scale, input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=layer.weight_scale, + input_scale=input_scale) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, + requires_grad=False) + else: + weight_scale = layer.weight_scale.data + if self.use_per_token_if_dynamic: + weight_scale = weight_scale.view(-1, 1) + layer.weight = Parameter(weight.t(), requires_grad=False) + # required by torch.compile to be torch.nn.Parameter + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + + else: + raise ValueError( + f"Unknown quantization scheme {self.weight_qscheme}") + + # INPUT SCALE + if self.is_static_input_scheme: + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) + else: + layer.input_scale = None + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + + # WEIGHT + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + # TODO: update create_xxx_parameter functions to return + # the newly added parameters + if self.weight_qscheme == "per_channel": + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes)), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + else: + assert self.weight_qscheme == "per_tensor" + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + + # min requirement for fp8 kernels + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE + if self.is_static_input_scheme: + input_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + input_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", input_scale) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + return self.fp8_linear.apply(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, + input_scale=layer.input_scale, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py new file mode 100644 index 0000000..ae68d5b --- /dev/null +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Callable, Optional + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel) +from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) + +logger = init_logger(__name__) + + +class QuarkW8A8Int8(QuarkScheme): + _kernel_backends_being_used: set[str] = set() + + def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool], + input_symmetric: Optional[bool]): + self.qscheme = qscheme + self.is_static_input_scheme = is_static_input_scheme + self.input_symmetric = input_symmetric + + @classmethod + def get_min_capability(cls) -> int: + # turing and up + return 75 + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + layer.logical_widths = output_partition_sizes + + scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( + is_channelwise=(self.qscheme == "per_channel"), + is_static_input_scheme=(self.is_static_input_scheme is True), + input_symmetric=(self.input_symmetric is True)) + + kernel_type = choose_scaled_mm_linear_kernel( + scaled_mm_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # WEIGHT + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.int8), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + if self.qscheme == "per_channel": + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes)), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + ChannelQuantZPParameter = ChannelQuantScaleParameter + weight_zero_point = ChannelQuantZPParameter( + data=torch.empty((sum(output_partition_sizes)), + dtype=torch.int8), + output_dim=0, + weight_loader=weight_loader) + else: + assert self.qscheme == "per_tensor" + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + PerTensorZPParameter = PerTensorScaleParameter + weight_zero_point = PerTensorZPParameter( + data=torch.empty(len(output_partition_sizes), + dtype=torch.int8), + weight_loader=weight_loader) + layer.register_parameter("weight_scale", weight_scale) + layer.register_parameter("weight_zero_point", weight_zero_point) + + # INPUT SCALE + if self.is_static_input_scheme: + input_scale = BasevLLMParameter(data=torch.empty( + 1, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("input_scale", input_scale) + + input_zero_point = BasevLLMParameter(data=torch.empty( + 1, dtype=torch.int8), + weight_loader=weight_loader) + layer.register_parameter("input_zero_point", input_zero_point) + + self.kernel = kernel_type(c=scaled_mm_linear_kernel_config, + w_q_param_name="weight", + w_s_param_name="weight_scale", + i_s_param_name="input_scale", + i_zp_param_name="input_zero_point", + azp_adj_param_name="azp_adj") + + # Checkpoints are serialized in quark format, which is + # different from the format the kernel may want. Handle repacking here. + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.register_parameter("weight_zero_point", None) + delattr(layer, 'weight_zero_point') + if self.input_symmetric: + layer.register_parameter("input_zero_point", None) + delattr(layer, 'input_zero_point') + + self.kernel.process_weights_after_loading(layer) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/quark/utils.py b/vllm/model_executor/layers/quantization/quark/utils.py new file mode 100644 index 0000000..99f5ec1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/quark/utils.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable, Mapping +from types import MappingProxyType +from typing import Any, Optional + +import regex as re + + +def deep_compare(dict1: Any, dict2: Any) -> bool: + if type(dict1) is not type(dict2): + return False + if isinstance(dict1, dict): + if dict1.keys() != dict2.keys(): + return False + return all(deep_compare(dict1[k], dict2[k]) for k in dict1) + elif isinstance(dict1, list): + return set(dict1) == set(dict2) + else: + return dict1 == dict2 + + +def should_ignore_layer( + layer_name: Optional[str], + ignore: Iterable[str], + fused_mapping: Mapping[str, list[str]] = MappingProxyType({}) +) -> bool: + if layer_name is None: + return False + + # layer_name = model.layers.0.self_attn.qkv_proj + # proj_name = qkv_proj + proj_name = layer_name.split(".")[-1] + + # Fused layers like gate_up_proj or qkv_proj will not be fused + # in the safetensors checkpoint. So, we convert the name + # from the fused version to unfused + check to make sure that + # each shard of the fused layer has the same scheme. + if proj_name in fused_mapping: + shard_proj_names = fused_mapping[proj_name] + + # Convert fused_name --> [shard_names] + shard_names = [ + layer_name.replace(proj_name, shard_proj_name) + for shard_proj_name in shard_proj_names + ] + + # Layer should be ignored if shards are ignored. + should_ignore_layer = None + for shard_name in shard_names: + should_ignore_shard = check_equal_or_regex_match( + layer_name=shard_name, targets=ignore) + + # If shard_idx=0, set layer ignore to match shard. + if should_ignore_layer is None: + should_ignore_layer = should_ignore_shard + + # If shard_idx=1+ confirm scheme matches prior shards. + elif should_ignore_shard != should_ignore_layer: + raise ValueError(f"Found a different quantization schemes for " + f"{shard_proj_names} in {layer_name}. vLLM " + "requires all to use the same scheme.") + + # Unfused layers like down_proj and o_proj will match + # the safetensors checkpoint already. + else: + should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name, + targets=ignore) + + assert should_ignore_layer is not None + return should_ignore_layer + + +def check_equal_or_regex_match(layer_name: str, + targets: Iterable[str]) -> bool: + """ + Checks whether a layer_name is exactly equal or a regex match for + if target starts with 're:' to any target in list. + """ + for target in targets: + if _is_equal_or_regex_match(layer_name, target): + return True + return False + + +def _is_equal_or_regex_match(value: str, + target: str, + check_contains: bool = False) -> bool: + """ + Checks whether a value is exactly equal or a regex match for target + if target starts with 're:'. If check_contains is set to True, + additionally checks if the target string is contained within the value. + """ + + if target.startswith("re:"): + pattern = target[3:] + if re.match(pattern, value): + return True + elif check_contains: + if target.lower() in value.lower(): + return True + elif target == value: + return True + return False diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py new file mode 100644 index 0000000..6830971 --- /dev/null +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -0,0 +1,289 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright © 2025, Oracle and/or its affiliates. + +import os +from typing import Any, Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + +logger = init_logger(__name__) +"""By default, use 8 bit as target precision, but it can be +overridden by setting the RTN_NUM_BITS envvar +""" +NUM_BITS = os.getenv('RTN_NUM_BITS', "8") +"""By default, use group size of 128 parameters, but it can be +overridden by setting the RTN_GROUP_SIZE envvar +""" +GROUP_SIZE = os.getenv('RTN_GROUP_SIZE', "128") + + +class RTNConfig(QuantizationConfig): + """Config class for RTN. + """ + + def __init__( + self, + weight_bits: int = int(NUM_BITS), + group_size: int = int(GROUP_SIZE), + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + + if self.weight_bits != 4 and self.weight_bits != 8: + raise ValueError( + "Currently, only 4-bit or 8-bit weight quantization is " + f"supported for RTN, but got {self.weight_bits} bits.") + + def __repr__(self) -> str: + return (f"RTNConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size})") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "rtn" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "RTNConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits, group_size) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["RTNLinearMethod"]: + if isinstance(layer, LinearBase): + return RTNLinearMethod(self) + return None + + +class RTNTensor: + """A wrapper over Tensor that enables quantization on-the-fly by + overloading the copy_ method. + """ + + def __init__(self, data: torch.Tensor, scale: torch.Tensor, + quant_config: RTNConfig) -> None: + self.data = data + self.scale = scale + self.quant_config = quant_config + + def narrow(self, dim, start, length): + factor = 1 if self.quant_config.weight_bits == 8 else 2 + return RTNTensor( + self.data.narrow(dim, start // factor, length // factor), + self.scale.narrow(dim, start, length), self.quant_config) + + @property + def shape(self): + shape = self.data.shape + factor = 1 if self.quant_config.weight_bits == 8 else 2 + return torch.Size((shape[0] * factor, shape[1])) + + def copy_(self, loaded_weight: torch.Tensor) -> None: + qweight, weight_scale = rtn_quantize(loaded_weight.cuda(), + self.quant_config.weight_bits, + self.quant_config.group_size) + + self.data.copy_(qweight) + self.scale.data.copy_(weight_scale) + + +class RTNParameter(Parameter): + """A wrapper over Parameter that returns RTNTensor (a wrapper over Tensor) + when its data is accessed. We need this wrapper for the data loading phase + only, so we can intercept a weight copying function (torch.Tensor.copy_) + and apply quantization on-the-fly. + """ + + def __new__(cls, data: torch.Tensor, **kwargs): + return super().__new__(cls, data=data, requires_grad=False) + + def __init__(self, data: torch.Tensor, scale: torch.Tensor, + quant_config: RTNConfig) -> None: + self.scale = scale + self.quant_config = quant_config + + @property + def data(self): + return RTNTensor(super().data, self.scale, self.quant_config) + + +class RTNLinearMethod(LinearMethodBase): + """Linear method for RTN. + + Args: + quant_config: The RTN quantization config. + """ + + def __init__(self, quant_config: RTNConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + output_size_per_partition = sum(output_partition_sizes) + num_groups_per_col = (input_size_per_partition // + self.quant_config.group_size + if self.quant_config.group_size != -1 else 1) + + scale = Parameter( + torch.empty(output_size_per_partition, + num_groups_per_col, + dtype=params_dtype), + requires_grad=False, + ) + factor = 1 if self.quant_config.weight_bits == 8 else 2 + + weight = RTNParameter(data=torch.empty(output_size_per_partition // + factor, + input_size_per_partition, + dtype=torch.int8), + scale=scale, + quant_config=self.quant_config) + + layer.register_parameter("weight", weight) + set_weight_attrs(weight, { + **extra_weight_attrs, + "input_dim": 1, + "output_dim": 0, + }) + + layer.register_parameter("scale", scale) + layer.output_size_per_partition = output_size_per_partition + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """torch.compile does not know how to deal with a Parameter subclass + (aka RTNParameter). As we don't really need RTNParameters for the + forward pass, we replace them with equivalent instances of Parameters. + """ + old_weight = layer.weight + assert isinstance(old_weight, RTNParameter) + data = old_weight.data.data + + delattr(layer, "weight") + + new_weight = Parameter(data=data, requires_grad=False) + layer.register_parameter("weight", new_weight) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qweight = layer.weight + scale = layer.scale + + weight = rtn_dequantize(qweight, scale) + out = F.linear(x, weight) + del weight + if bias is not None: + out.add_(bias) + + return out + + +def rtn_quantize(tensor: torch.Tensor, num_bits: int, + group_size: int) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize a tensor using per-group static scaling factor. + + Args: + tensor: The input tensor. + num_bits: Target precision for the result (supported values are + 8 or 4). + group_size: Quantization granularity. + If equal to -1, each row in the input tensor is treated + as one group. + """ + + q_range = 2**num_bits + num_groups = (tensor.shape[0] * tensor.shape[1] // + group_size if group_size != -1 else tensor.shape[0]) + """Calculate a scaling factor per input group. + """ + input_flat = tensor.reshape(num_groups, -1) + input_min = torch.min(input_flat, dim=1, keepdim=True)[0] + input_max = torch.max(input_flat, dim=1, keepdim=True)[0] + input_max_abs = torch.max(input_min.abs(), input_max.abs()) + scale = (input_max_abs * 2.0 / (q_range - 1)) + """Scale each input group, truncate and round to the nearest integer. + """ + scaled_input = input_flat / scale + scaled_input = scaled_input.clamp(-q_range // 2, q_range // 2 - 1) + scaled_input = scaled_input.round() + + scale = scale.reshape(tensor.shape[0], -1).contiguous() + inputs_q = scaled_input.reshape(tensor.shape).to(torch.int8) + inputs_q = inputs_q.contiguous() + + if num_bits == 4: + """Pack two 4-bit values into each byte. + """ + inputs_q = (inputs_q[:, 1::2] << 4) | (inputs_q[:, ::2] & 0xf) + inputs_q = inputs_q.reshape(tensor.shape[0] // 2, tensor.shape[1]) + inputs_q = inputs_q.contiguous() + + return inputs_q, scale + + +def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """Dequantize a tensor using per-group static scaling factors. + + Args: + tensor: The input tensor. + scale: The tensor with per-group scale factors. + """ + + num_groups = scale.size(0) * scale.size(1) + input_dim, output_dim = tensor.shape + + num_bits = 8 if input_dim == scale.size(0) else 4 + if num_bits == 4: + input_dim *= 2 + + data = torch.empty((input_dim, output_dim), + dtype=scale.dtype, + device=tensor.device) + + if num_bits == 8: + data.copy_(tensor) + else: + """Unpack two 4-bit values from each byte. + """ + tensor = tensor.reshape(input_dim, output_dim // 2) + for i in range(2): + data[:, i::2] = (tensor << 4 * (1 - i)) >> 4 + """Scale each input group with its scaling factor. + """ + scale = scale.reshape(num_groups, -1) + data = data.reshape(num_groups, -1) + data = torch.mul(data, scale) + + input_deq = data.reshape((input_dim, output_dim)).contiguous() + return input_deq diff --git a/vllm/model_executor/layers/quantization/schema.py b/vllm/model_executor/layers/quantization/schema.py new file mode 100644 index 0000000..a108152 --- /dev/null +++ b/vllm/model_executor/layers/quantization/schema.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This file contains the Pydantic schemas for various quantization-related +parameters. When a relevant quantization technique is specified, these +parameters are loaded in the form of a JSON alongside the model weights +and augment the model with additional information needed for use of that +technique. The format of this JSON should be specified by one or more +schemas contained here. + +For example, when the KV cache is quantized to FP8-E4M3 (currently only +possible on ROCm), the model can be optionally augmented with KV cache +scaling factors. +""" + +from typing import Optional + +from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator + + +class KVCacheQuantSchema(BaseModel): + dtype: str + # Each key is a TP rank. Each value is a dictionary mapping a TP rank's + # layer indices to their per-tensor KV cache scaling factor. + # TODO: Consider pulling this and its validation methods out into its + # own schema class (tricky as its members are variable) + scaling_factor: dict[int, dict[int, float]] + + @model_validator(mode="after") + def check_is_fp8(self) -> "KVCacheQuantSchema": + assert self.dtype == "float8_e4m3fn", ( + "Loaded scaling factors intended for KV cache dtype = " + f"{self.dtype} rather than float8_e4m3fn!") + return self + + @model_validator(mode="after") + def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema": + context = info.context + if context: + tp_size = context["tp_size"] + num_hidden_layers = context["num_hidden_layers"] + assert len(self.scaling_factor) == tp_size, ( + f"Loaded dictionary has TP size {len(self.scaling_factor)} " + f"but LLM engine is currently running with TP size {tp_size}.") + for tp_rank, layer_maps in self.scaling_factor.items(): + assert len(layer_maps) == num_hidden_layers, ( + f"KV cache scales map for TP rank {tp_rank} is malformed. " + f"Expected {num_hidden_layers} layers, got " + f"{len(layer_maps)}.") + for i in range(tp_size): + assert i in self.scaling_factor, ( + f"KV cache scales map for TP rank {i} not found.") + return self + + @model_validator(mode="after") + def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema": + context = info.context + if context: + tp_rank = context["tp_rank"] + num_hidden_layers = context["num_hidden_layers"] + layer_scales_map = self.scaling_factor[tp_rank] + for i in range(num_hidden_layers): + assert i in layer_scales_map, ( + f"Could not find KV cache scales for layer {i} in " + f"TP rank {tp_rank}.") + return self + + +class QuantParamSchema(BaseModel): + # TODO: Generalize and extend with more fields + # (e.g. weights/activations params) once functionality is enabled + model_config = ConfigDict(protected_namespaces=()) + model_type: Optional[str] + kv_cache: KVCacheQuantSchema + + @model_validator(mode="after") + def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema": + context = info.context + if context: + model_type = context.get("model_type", None) + if model_type is not None: + assert model_type == self.model_type, ( + f"Model type is {model_type} but loaded " + f"scaling factors belonging to different " + f"model type {self.model_type}!") + return self diff --git a/vllm/model_executor/layers/quantization/slimquant_w4a8.py b/vllm/model_executor/layers/quantization/slimquant_w4a8.py new file mode 100644 index 0000000..6051932 --- /dev/null +++ b/vllm/model_executor/layers/quantization/slimquant_w4a8.py @@ -0,0 +1,398 @@ +from typing import Any, Callable, Dict, List, Optional + +import torch +from vllm.model_executor.utils import set_weight_attrs +from vllm.distributed import get_tensor_model_parallel_world_size +from torch.nn.parameter import Parameter +from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, + FusedMoeWeightScaleSupported) +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) +from lmslim.layers.gemm.int8_utils import ( + per_token_group_quant_int8, + per_token_quant_int8) +from vllm import _custom_ops as ops +from vllm.utils import W8a8GetCacheJSON + +import os +from vllm import _custom_ops as ops +from vllm import envs + +W8A8_TRITONJSON=W8a8GetCacheJSON() + +def baseline_scaled_mm(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + scales= scale_a* scale_b.T + gemmout= torch.mm( + a.to(dtype=torch.float32), b.to(dtype=torch.float32)) + output = (scales *gemmout).to(out_dtype) + if bias is not None: + output = output + bias + return output.to(out_dtype) + + +class SlimQuantW4A8Int8Config(QuantizationConfig): + """Config class for W8A8 Int8 Quantization. + + - Weight: static, per-channel, symmetric + - Activation: dynamic, per-token, symmetric + """ + + def __init__(self): + pass + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def get_name(self) -> str: + return "slimquant_w4a8" + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8Config": + return cls() + + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> Optional["QuantizeMethodBase"]: + + if isinstance(layer, LinearBase): + return SlimQuantW4A8Int8LinearMethod(self) + elif isinstance(layer, FusedMoE): + return SlimQuantW4A8Int8MoEMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: SlimQuantW4A8Int8Config): + self.quantization_config = quantization_config + self.tritonsingleton= W8a8GetCacheJSON() + self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + n=layer.weight.shape[0] + k=layer.weight.shape[1] + + if self.w8a8_strategy==1: + if {n,k} not in self.tritonsingleton.weight_shapes: + self.tritonsingleton.weight_shapes.append({n,k}) + json_file=self.tritonsingleton.get_w8a8json_name(n,k) + configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k) + + if configs_dict: + self.tritonsingleton.triton_json_dict.update(configs_dict) + + for key, value in configs_dict.items(): + m=int(key.split('_')[0]) + ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value) + else: + weight_data=layer.weight.data + _weight=weight_data.T.contiguous().reshape(n,-1) + layer.weight.data=_weight + + layer.weight = Parameter(layer.weight.t(), requires_grad=False) + layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + weight_loader = extra_weight_attrs.get("weight_loader") + self.logical_widths = output_partition_sizes + + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8 + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + input_quant_args: Optional[list[torch.Tensor]] = None + ): + if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None: + assert len(input_quant_args) == 2 + x_q, x_scale = input_quant_args + else: + x_q, x_scale = per_token_quant_int8(x) + + if self.w8a8_strategy==1: + m=x_q.shape[0] + k=x_q.shape[1] + n=layer.weight.shape[1] + + if len(W8A8_TRITONJSON.triton_json_dict)==0: + best_config=None + + elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict: + if m<=16: + m_=m + elif m<=64: + m_= (m + 3) & -4 #取值到最近的4的倍数 + elif m<=160: + m_=(m + 7) & -8 + + elif m<200: #256 + m_=160 + elif m<480: #512 + m_=256 + elif m<960: #1024 + m_=512 + elif m<2048: + m_=1024 + elif m<4096: + m_=2048 + elif m<6000: + m_=4096 + else: + m_=8192 + + best_config=W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"] + + else: + best_config=None + + #if best_config==None: + # print("m:{},n:{},k:{}".format(m,n,k)) + # print("config not found!") + + return ops.triton_scaled_mm(x_q, + layer.weight, + scale_a=x_scale, + scale_b=layer.weight_scale, + out_dtype=x.dtype, + bias=bias,best_config=best_config) + elif self.w8a8_strategy==2: + return ops.cutlass_scaled_mm(x_q, + layer.weight, + scale_a=x_scale, + scale_b=layer.weight_scale, + out_dtype=x.dtype, + bias=bias) + else: + return ops.rocblas_scaled_mm(x_q, + layer.weight, + scale_a=x_scale, + scale_b=layer.weight_scale, + out_dtype=x.dtype, + bias=bias) + + +class SlimQuantW4A8Int8MoEMethod: + """MoE method for W4A8INT8. + Supports loading INT8 checkpoints with static weight scale and + dynamic/static activation scale. + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + Args: + quant_config: The quantization config. + """ + + def __new__(cls, *args, **kwargs): + + if not hasattr(cls, "_initialized"): + original_init = cls.__init__ + new_cls = type( + cls.__name__, + (FusedMoEMethodBase,), + { + "__init__": original_init, + **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, + }, + ) + obj = super(new_cls, new_cls).__new__(new_cls) + obj.__init__(*args, **kwargs) + return obj + return super().__new__(cls) + + def __init__(self, quant_config): + self.quant_config = quant_config + self.tritonsingleton= W8a8GetCacheJSON() + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + tp_size = get_tensor_model_parallel_world_size() + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size//2, dtype=torch.int8 + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty(num_experts, hidden_size, intermediate_size//2, dtype=torch.int8), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) + + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + w13_input_scale = None + layer.register_parameter("w13_input_scale", w13_input_scale) + + w2_input_scale = None + layer.register_parameter("w2_input_scale", w2_input_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + E=layer.w13_weight.shape[0] + N1=layer.w13_weight.shape[1] + N2=layer.w2_weight.shape[1] + K=N1//2 + if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes: + self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K]) + + TOPK= self.tritonsingleton.topk + + json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK,use_int4_w4a8=True) + configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK) + + #warmup + if configs_dict: + self.tritonsingleton.triton_moejson_dict.update(configs_dict) + + layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False) + layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False) + layer.w13_weight_scale = Parameter( + layer.w13_weight_scale.data, requires_grad=False + ) + layer.w2_weight_scale = Parameter( + layer.w2_weight_scale.data, requires_grad=False + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + use_nn_moe: Optional[bool] = False, + routed_scaling_factor: Optional[float] = None, + use_fused_gate: Optional[bool] = False, + **_ + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.") + # Expert selection + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + routed_scaling_factor=routed_scaling_factor, + use_fused_gate=use_fused_gate + ) + + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_int4_w4a8=True, + per_channel_quant=True, + activation=activation, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + w1_scale=(layer.w13_weight_scale), + w2_scale=(layer.w2_weight_scale), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + use_nn_moe=use_nn_moe, + ) diff --git a/vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py b/vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py new file mode 100644 index 0000000..97ee514 --- /dev/null +++ b/vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py @@ -0,0 +1,275 @@ +from typing import Any, Callable, Dict, List, Optional +import os +import torch +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.model_executor.utils import set_weight_attrs +from vllm.distributed import get_tensor_model_parallel_world_size +from torch.nn.parameter import Parameter +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase) +from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig, + QuantizeMethodBase) +from vllm.model_executor.layers.quantization.utils.w4a8_utils import w4a8_weight_repack_impl +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, + FusedMoeWeightScaleSupported) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + ModelWeightParameter) +from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod + +try: + from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin +except Exception: + print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n") + + +class MarlinMoeWorkspace: + """ + Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE. + global_reduce_buffer will take 1.5MB * cus (about 120MB for BW200) memoery in each device + """ + _instances = {} + def __new__(cls, device): + if device not in cls._instances: + instance = super().__new__(cls) + instance._initialized = False + cls._instances[device] = instance + return cls._instances[device] + + def __init__(self, device): + if self._initialized: + return + sms = torch.cuda.get_device_properties(device).multi_processor_count + self.workspace = torch.zeros( + 500, dtype=torch.int, device=device, requires_grad=False + ) + self.global_reduce_buffer = torch.zeros( + sms * 6 * 128 * 512, dtype=torch.int, device=device, requires_grad=False + ) + self._initialized = True + + def get_buffers(self): + return self.workspace, self.global_reduce_buffer + +def baseline_scaled_mm(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + scales= scale_a* scale_b.T + gemmout= torch.mm( + a.to(dtype=torch.float32), b.to(dtype=torch.float32)) + output = (scales *gemmout).to(out_dtype) + if bias is not None: + output = output + bias + return output.to(out_dtype) + + +class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig): + """Config class for W4A8 Int8 Quantization. + - Weight: static, per-channel, symmetric + - Activation: dynamic, per-token, symmetric + """ + + def __init__(self): + pass + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def get_name(self) -> str: + return "slimquant_w4a8_marlin" + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8MarlinConfig": + return cls() + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + if hf_quant_cfg.get("quant_method") == "slimquant_w4a8" \ + and user_quant == "slimquant_w4a8_marlin": + return cls.get_name() + return None + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> Optional["QuantizeMethodBase"]: + + if isinstance(layer, LinearBase): + return SlimQuantW4A8Int8LinearMethod(self) + elif isinstance(layer, FusedMoE): + return SlimQuantW4A8Int8MarlinMoEMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class SlimQuantW4A8Int8MarlinMoEMethod: + """MoE method for W4A8INT8 Marlin. + Supports loading INT8 checkpoints with static weight scale and + dynamic/static activation scale. + Args: + quant_config: The quantization config. + """ + + def __new__(cls, *args, **kwargs): + + if not hasattr(cls, "_initialized"): + original_init = cls.__init__ + new_cls = type( + cls.__name__, + (FusedMoEMethodBase,), + { + "__init__": original_init, + **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, + }, + ) + obj = super(new_cls, new_cls).__new__(new_cls) + obj.__init__(*args, **kwargs) + return obj + return super().__new__(cls) + + def __init__(self, quant_config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + tp_size = get_tensor_model_parallel_world_size() + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size//2, dtype=torch.int8 + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty(num_experts, hidden_size, intermediate_size//2, dtype=torch.int8), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) + + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + w13_input_scale = None + layer.register_parameter("w13_input_scale", w13_input_scale) + + w2_input_scale = None + layer.register_parameter("w2_input_scale", w2_input_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.w13_weight_scale = Parameter( + layer.w13_weight_scale.data, requires_grad=False + ) + layer.w2_weight_scale = Parameter( + layer.w2_weight_scale.data, requires_grad=False + ) + + layer.w13_weight = Parameter(w4a8_weight_repack_impl(layer.w13_weight), requires_grad=False) + layer.w2_weight = Parameter(w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + use_nn_moe: Optional[bool] = False, + routed_scaling_factor: Optional[float] = None, + use_fused_gate: Optional[bool] = False, + **_ + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.") + # Expert selection + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + routed_scaling_factor=routed_scaling_factor, + use_fused_gate=use_fused_gate + ) + workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers() + return fused_experts_impl_w4a8_marlin( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + workspace=workspace, + global_reduce_buffer=global_reduce_buffer, + inplace=True, + use_int4_w4a8=True, + per_channel_quant=True, + activation=activation, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + w1_scale=(layer.w13_weight_scale), + w2_scale=(layer.w2_weight_scale), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + use_nn_moe=use_nn_moe, + ) diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py new file mode 100644 index 0000000..63b2ab6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/torchao.py @@ -0,0 +1,212 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + + +def should_skip(prefix: str, skip_modules: list[str]) -> bool: + """ + Robust skipping logic: + should_skip("model.model.layers.1.q_proj", + ["model.model.layers.1.q_proj"]) # True + should_skip("model.model.layers.10.o_proj", ["o_proj"]) -> True + should_skip("visual.model.layers.1.q_proj", ["visual"]) -> True + should_skip("model.model.layers.1.q_proj", ["layers.1"]) -> True + should_skip("model.model.layers.11.q_proj", ["layers.1"]) -> False + """ + for s in skip_modules: + if prefix == s: + return True + if f".{s}." in f".{prefix}.": + return True + return False + + +class TorchAOConfig(QuantizationConfig): + """Config class for torchao.""" + + def __init__(self, + torchao_config, + skip_modules: Optional[list[str]] = None) -> None: + """ + # TorchAO quantization relies on tensor subclasses. In order, + # to enable proper caching this needs standalone compile + if is_torch_equal_or_newer("2.8.0.dev"): + os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1" + logger.info( + "Using TorchAO: Setting VLLM_TEST_STANDALONE_COMPILE=1") + + # TODO: remove after the torch dependency is updated to 2.8 + if is_torch_equal_or_newer( + "2.7.0") and not is_torch_equal_or_newer("2.8.0.dev"): + os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1" + logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1") + """ + super().__init__() + self.torchao_config = torchao_config + self.skip_modules = skip_modules or [] + + def __repr__(self) -> str: + return f"TorchAOConfig({self.torchao_config})" + + def get_name(self) -> QuantizationMethods: + return "torchao" + + def get_supported_act_dtypes(self) -> list[torch.dtype]: + return [torch.float32, torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @staticmethod + def get_config_filenames() -> list[str]: + return ["config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig": + """Create the quant config from an hf model config""" + try: + from torchao.core.config import config_from_dict + except ImportError as err: + raise ImportError( + "Please install torchao>=0.10.0 via " + "`pip install torchao>=0.10.0` to use torchao quantization." + ) from err + + hf_config = cls.get_from_keys_or(config, ["quant_type"], None) + assert hf_config is not None, "quant_type must be specified" + assert len(hf_config) == 1 and "default" in hf_config, ( + "Expected only one key 'default' in quant_type dictionary") + quant_type = hf_config["default"] + ao_config = config_from_dict(quant_type) + + # Adds skipped modules defined in "modules_to_not_convert" + skip_modules = config.get("modules_to_not_convert", []) or [] + + # Adds skipped modules defined in "module_fqn_to_config" + _data = quant_type.get("_data", {}) + if not isinstance(_data, dict): + _data = {} + + module_fqn = _data.get("module_fqn_to_config", {}) + if not isinstance(module_fqn, dict): + module_fqn = {} + + for layer, layer_cfg in module_fqn.items(): + if layer_cfg is None: + skip_modules.append(layer) + + return cls(ao_config, skip_modules) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if not isinstance(layer, LinearBase): + return None + + from torchao.quantization import ModuleFqnToConfig + + if should_skip(prefix, self.skip_modules): + return UnquantizedLinearMethod() + + module_fqn = prefix + if isinstance(self.torchao_config, ModuleFqnToConfig): + module_fqn_to_config = self.torchao_config.module_fqn_to_config + c = module_fqn_to_config.get( + module_fqn) or module_fqn_to_config.get("_default", None) + if c is not None: + current_torchao_config = TorchAOConfig(c, self.skip_modules) + return TorchAOLinearMethod(current_torchao_config) + else: + return UnquantizedLinearMethod() + + return TorchAOLinearMethod(self) + + def get_scaled_act_names(self) -> list[str]: + return [] + + +def torchao_quantize_param_data(param: torch.Tensor, + torchao_config: Any) -> torch.nn.Parameter: + """Quantize a Tensor with torchao quantization specified by torchao_config + + Args: + `param`: weight parameter of the linear module + `torchao_config`: type of quantization and their arguments we want to + use to quantize the Tensor + """ + from torchao.core.config import AOBaseConfig + from torchao.quantization import quantize_ + + assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}" + """ + Avoid real weight allocation for faster load, since we will + end up setting it to param. + """ + with torch.device("meta"): + dummy_linear = torch.nn.Linear(param.shape[1], + param.shape[0], + bias=False) + + dummy_linear.weight = param + quantize_(dummy_linear, torchao_config) + return dummy_linear.weight + + +class TorchAOLinearMethod(LinearMethodBase): + """Linear method for torchao. + + Args: + torchao_config: The torchao quantization config, a string + that encodes the type of quantization and all relevant arguments. + """ + + def __init__(self, quant_config: TorchAOConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight = Parameter( + torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + weight = torchao_quantize_param_data(weight, + self.quant_config.torchao_config) + + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return F.linear(x, layer.weight, bias) diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py new file mode 100644 index 0000000..83c8a98 --- /dev/null +++ b/vllm/model_executor/layers/quantization/tpu_int8.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Optional + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter + +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.parameter import ModelWeightParameter + +ACTIVATION_SCHEMES = ["none"] + + +class Int8TpuConfig(QuantizationConfig): + """Int8 Quantization Config class for TPU Backend.""" + + def __init__( + self, + activation_scheme: str = "none", + ) -> None: + super().__init__() + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError( + f"Unsupported activation scheme {activation_scheme}") + self.activation_scheme = activation_scheme + + def get_name(self) -> QuantizationMethods: + return "tpu_int8" + + def get_supported_act_dtypes(self) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + raise NotImplementedError( + "This function should not be called with TPU Backend") + + @staticmethod + def get_config_filenames() -> list[str]: + return [] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "Int8TpuConfig": + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + return cls(activation_scheme=activation_scheme) + + def get_quant_method(self, layer: Module, + prefix: str) -> Optional["TPUInt8LinearMethod"]: + if isinstance(layer, LinearBase): + return TPUInt8LinearMethod(self) + return None + + +class TPUInt8LinearMethod(LinearMethodBase): + """Int8 Linear method for TPU Quant. """ + + def __init__(self, quant_config: Int8TpuConfig): + self.quant_config = quant_config + + def create_weights(self, layer: Module, input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + + weight_loader = extra_weight_attrs.get("weight_loader") + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + def _quantize_weight( + self, weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + weight_dtype = weight.dtype + weight = weight.cpu().to(torch.float32) + n_bit = 8 + eps = 1e-5 + max_int = 2**(n_bit - 1) - 1 + min_int = -(2**(n_bit - 1)) + max_val = weight.abs().amax(dim=-1, keepdim=True) + max_val = max_val.clamp(min=eps) + qscale = max_val / max_int + qweight = torch.clamp(torch.round(weight * (1.0 / qscale)), min_int, + max_int).to(torch.int8) + qscale = qscale.squeeze().to(weight_dtype) + return qweight, qscale + + def process_weights_after_loading(self, layer: Module) -> None: + layer.weight = Parameter(layer.weight.data, requires_grad=False) + device = layer.weight.device + qweight, qscale = self._quantize_weight(layer.weight) + qweight = qweight.to(device) + qscale = qscale.to(device) + layer.weight = Parameter(qweight, requires_grad=False) + layer.scale = Parameter(qscale, requires_grad=False) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + try: + import torch_xla.experimental.xla_quantized_matmul # noqa: F401 + except ImportError as err: + raise ImportError( + "Please install torch_xla by following the instructions at " + "https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html " # noqa: E501 + "to run vLLM on TPU.") from err + weight = layer.weight + scale = layer.scale + out = torch.ops.xla.quantized_matmul(x, weight, scale) + if bias is not None: + out = out + bias + return out diff --git a/vllm/model_executor/layers/quantization/utils/__init__.py b/vllm/model_executor/layers/quantization/utils/__init__.py new file mode 100644 index 0000000..6ad56ba --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .layer_utils import replace_parameter, update_tensor_inplace + +__all__ = ['update_tensor_inplace', 'replace_parameter'] diff --git a/vllm/model_executor/layers/quantization/utils/allspark_utils.py b/vllm/model_executor/layers/quantization/utils/allspark_utils.py new file mode 100644 index 0000000..1992b4d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/allspark_utils.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD = 1024 +ALLSPARK_SUPPORTED_QUANT_TYPES = [scalar_types.uint8b128] +ALLSPARK_AMPERE_N_ALIGN = 16 +ALLSPARK_AMPERE_K_ALIGN = 16 + + +def check_allspark_supported_dtype_shape(input_size_per_partition: int, + output_size_per_partition: int, + group_size: int, + weight_dtype: ScalarType, + act_dtype: torch.dtype): + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + + # For Ampere GPU + if device_capability >= 80 and device_capability < 90: + if group_size != -1: + return False, \ + "For Ampere GPU, AllSpark does not support group_size "\ + f"= {group_size}. Only group_size = -1 are supported." + + if weight_dtype not in ALLSPARK_SUPPORTED_QUANT_TYPES: + return False, "For Ampere GPU, AllSpark does not support "\ + f"quant type ({weight_dtype}). Only quant type "\ + f"({ALLSPARK_SUPPORTED_QUANT_TYPES}) are supported." + + if input_size_per_partition % ALLSPARK_AMPERE_K_ALIGN != 0 \ + or output_size_per_partition % ALLSPARK_AMPERE_N_ALIGN != 0: + return False, \ + "AllSpark needs input_size_per_partition % "\ + f"{ALLSPARK_AMPERE_K_ALIGN} = 0 and "\ + f"output_size_per_partition % {ALLSPARK_AMPERE_N_ALIGN} = 0 "\ + "for Ampere GPU optimized kernels." + + if act_dtype != torch.float16 and act_dtype != torch.bfloat16: + return False, \ + "AllSpark only supports act_dtype = float16 or bfloat16,"\ + f"for Ampere GPU, but got act_dtype = {act_dtype}." + else: + return False, "AllSpark currently does not support "\ + f"device_capability = {device_capability}." + + return True, None diff --git a/vllm/model_executor/layers/quantization/utils/bitblas_utils.py b/vllm/model_executor/layers/quantization/utils/bitblas_utils.py new file mode 100644 index 0000000..82ee3ed --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/bitblas_utils.py @@ -0,0 +1,208 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +MINIMUM_BITBLAS_VERSION = "0.1.0" + +BITBLAS_MIN_WEIGHT_SIZE_N = 16 +BITBLAS_MIN_WEIGHT_SIZE_K = 16 +GPTQ_BITBLAS_MAX_PARALLEL = 16 + +BITBLAS_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# For dynamic shape code generation +BITBLAS_OPTIMIZE_FEATURES = [1, 16, 32, 64, 128, 256, 512, 1024] +# If want to enable high performance for contiguous batching +# Please use the following values +BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS = [16, 32, 64, 128, 256, 512, 1024] + +BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8] +BITBLAS_SUPPORTED_SYM = [False, True] + + +# Determines the supported quantization types for BitBLAS based on the +# device's capability and whether zero-point (zp) is used. +def query_bitblas_supported_quant_types(has_zp: bool, + device_capability: Optional[int] = None + ): + if device_capability is None: + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + + if device_capability < 70: + return [] + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_bitblas is merged into "gptq_bitblas" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def _check_bitblas_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]: + + if device_capability is None: + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + + supported_types = query_bitblas_supported_quant_types( + has_zp, device_capability) + + if quant_type not in supported_types: + return (False, f"BitBLAS does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in BITBLAS_SUPPORTED_GROUP_SIZES): + return (False, f"BitBLAS does not support group_size = {group_size}. " + f"Only group_sizes = {BITBLAS_SUPPORTED_GROUP_SIZES} " + "are supported.") + + # Finally, check if bitblas is installed + try: + import bitblas + if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: + raise ImportError("bitblas version is wrong. Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + except ImportError: + return False, "BitBLAS is not installed." + + return True, None + + +def check_bitblas_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None) -> bool: + cond, _ = _check_bitblas_supported(quant_type, group_size, has_zp, + device_capability) + return cond + + +def verify_bitblas_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False) -> None: + cond, err_msg = _check_bitblas_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_bitblas_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_N != 0: + raise ValueError(f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {BITBLAS_MIN_WEIGHT_SIZE_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + # Validate input_size_per_partition + if input_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_K != 0: + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {BITBLAS_MIN_WEIGHT_SIZE_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}." + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + +def check_bitblas_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> tuple[bool, Optional[str]]: + try: + verify_bitblas_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) + except ValueError as e: + return False, e.__str__() + return True, None + + +def bitblas_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def bitblas_repeat_scales_on_all_ranks(act_order: bool, group_size: int, + is_row_parallel: bool) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def bitblas_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + +def bitblas_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + +def bitblas_sort_g_idx( + g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def unpack_gptq_qzeros(qzeros, bits, is_gptq_v2=False) -> torch.Tensor: + qzeros = qzeros.view(torch.int32) + elems_per_int32 = 32 // bits + unpacked_zeros = torch.zeros( + (qzeros.shape[0], qzeros.shape[1] * elems_per_int32), + dtype=torch.int8, + device=qzeros.device, + requires_grad=False, + ) + + for col in range(unpacked_zeros.shape[1]): + i = col % elems_per_int32 + unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> + (bits * i)) & 0xF + if not is_gptq_v2: + return unpacked_zeros + 1 + return unpacked_zeros + + +def unpack_gptq_qweight(qweight, bits): + qweight = qweight.view(torch.int8) + elems_per_int8 = 8 // bits + unpacked_weight = torch.zeros( + (qweight.shape[0], qweight.shape[1] * elems_per_int8), + dtype=torch.int8, + device=qweight.device, + requires_grad=False, + ) + for col in range(unpacked_weight.shape[1]): + i = col % elems_per_int8 + unpacked_weight[:, col] = (qweight[:, col // elems_per_int8] >> + (bits * i)) + + return torch.bitwise_and(unpacked_weight, 2**bits - 1) diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..e9a50e1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..119969d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..119969d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..3e8ebf3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..2bb5b45 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..6496a38 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..6e2aeee --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b0f9442 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b3bf9ea --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7e52ab6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7e52ab6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..bee8d03 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9da876d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..3618053 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..0a1a252 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..46a982f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9696611 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..d6279a1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..defaacb --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..ecc2fda --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..ecc2fda --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..3bc0036 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..310dff4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..035ec02 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..206c8a2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..8b49f27 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..edc2353 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..987c8f6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..108af31 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..108af31 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..43b5bdb --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..bffa749 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..851bc9f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..f96f127 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..d1227c2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..fe3e18c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b3ed43a --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..abd1915 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..abd1915 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..e4d5b2d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..137b9dd --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..77ba0d7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..1c61451 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..38cac46 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..8e6ebe2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..63e661c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..459062e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..1225d84 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..03e8235 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..bb61d83 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..bb61d83 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..d44e384 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c559a69 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..cf35403 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..8ec2005 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..65840aa --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..1a457b9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..574cf49 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..574cf49 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..0a5d7bf --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..4e120d6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..eccb86a --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..125fe36 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..4415cc9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7bfaf93 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7bfaf93 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..cb91a27 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..88af484 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..5c29874 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..dd06972 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..125fe36 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7c039b4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c2bd478 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c2bd478 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..4990268 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..18afdd9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7febe3d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..56b939e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..51d10bb --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..1480e09 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..63d9a0b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..f5fdec3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..6bd350c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..5c604b9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..75906ad --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..75906ad --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7fa398c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..f15d8f6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b4d25ae --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..fdc6437 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..fdc6437 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9d7658b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..cd3e078 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..2b9f0d1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9d5a329 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7f449db --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..634c1bf --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7eaa7d1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7eaa7d1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..03dba5a --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..96e1594 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..d979c6b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..5ffd367 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..be93dfe --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..19452df --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..3382554 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..3382554 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9a5ff48 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..6eb22de --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..eabc423 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..84ef35e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..e6d9107 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c9d18c9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c9d18c9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c746e70 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..0b4746c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..386928d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..51e237b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..8ec2005 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..202acf2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..6280219 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..983525f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,18 @@ +{ + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..11a9bce --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c298da8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..56a766c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..56a766c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..386ee59 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..60df5e3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..40c01c0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..4f1747b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c6fd365 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..53bbaca --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..cb993c8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..f250d3f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..f250d3f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..ffe67dc --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..2a17e16 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..160f12e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b259993 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..e5c4a1d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..a71ab88 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..56d3e1f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..bbd4df4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..bbd4df4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..eda96e7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..bd0767b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..2bf5eb2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..29f7651 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..6db1385 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9cdff13 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7bb8e87 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..7bb8e87 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..1a47cae --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..8dd5ae5 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9c908e8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..0a1e14c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..6d1a8b5 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..e77abaf --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..15b1c93 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..0cf6a47 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..01327b2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..6f9bd75 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..f050b75 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..f050b75 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..12eea5f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9db9dae --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..f78e706 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..8ff12e6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..365f8d0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..f080ea5 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..4532f93 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..0cf6a47 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..e9bf044 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c7122d3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..4a3ccc0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..4a3ccc0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..1d3ce5c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..ca7f32b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c37aced --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..5acea24 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..d962889 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..3cea21b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..24ef112 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..24ef112 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..3ab5796 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..58cdd93 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..d6bef7f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b72e037 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b4b08ea --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..a8141f5 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c911a8e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c911a8e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..3cb7eaa --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..8df6e4b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..293adce --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..9d7edc3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..c9566d7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..d86b349 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..d86b349 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..e471687 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b4c3249 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 0000000..b4c3249 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 16, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "1536": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "3072": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "kpack": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py new file mode 100644 index 0000000..b8d3a3d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -0,0 +1,653 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from https://github.com/sgl-project/sglang/pull/2575 +import functools +import json +import os +from typing import Any, Callable, Optional, Union, List + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + scaled_dequantize) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + CUTLASS_BLOCK_FP8_SUPPORTED) +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton +from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm + +logger = init_logger(__name__) + + +def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool: + if isinstance(x, torch.Tensor): + x = x.dtype + return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz + + +def cutlass_scaled_mm( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: List[int], + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + return ops.cutlass_scaled_mm(A, + B.T, + out_dtype=output_dtype, + scale_a=As, + scale_b=Bs.T) + + +def rocm_aiter_gemm_w8a8_blockscale_impl( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: List[int], + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + import aiter as rocm_aiter + + return rocm_aiter.gemm_a8w8_blockscale_CK(A, B, As, Bs, dtype=output_dtype) + + +def rocm_aiter_gemm_w8a8_blockscale_fake( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: List[int], + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + + m = A.shape[0] + n = B.shape[0] + Y = torch.empty(m, n, dtype=output_dtype, device=A.device) + return Y + + +if current_platform.is_rocm(): + direct_register_custom_op( + op_name="rocm_aiter_gemm_w8a8_blockscale", + op_func=rocm_aiter_gemm_w8a8_blockscale_impl, + mutates_args=[], + fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake, + dispatch_key=current_platform.dispatch_key, + ) + + +def dispatch_w8a8_blockscale_func( + use_cutlass: bool, use_aiter_and_is_supported: bool +) -> Callable[[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + List[int], + torch.dtype, +], torch.Tensor]: + if use_cutlass: + return cutlass_scaled_mm + if (use_aiter_and_is_supported): + return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale + return w8a8_block_fp8_matmul + + +def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor): + """ + Check if DeepGEMM should be used based on the output dtype and weight shape. + DeepGEMM is only supported for bfloat16 output dtype and weights with shape + divisible by 128. + """ + + return (current_platform.is_cuda() + and current_platform.is_device_capability(90) and has_deep_gemm() + and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16 + and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) + + +# TODO fix ROCm->Triton custom path: +# https://github.com/vllm-project/vllm/issues/14397 +def apply_w8a8_block_fp8_linear( + input: torch.Tensor, + weight: torch.Tensor, + block_size: List[int], + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, + use_aiter_and_is_supported: bool = False, +) -> torch.Tensor: + assert input_scale is None + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[0]] + output_dtype = input.dtype + + if should_use_deepgemm(output_dtype, weight): + + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[0]] + + q_input, x_scale = per_token_group_quant_fp8( + input_2d, + block_size[1], + column_major_scales=True, + ) + + import vllm.model_executor.layers.quantization.deepgemm # noqa: F401 + output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm( + q_input, + weight, + x_scale, + weight_scale, + block_size, + output_dtype=output_dtype) + if bias is not None: + output += bias + return output.to(dtype=output_dtype).view(*output_shape) + + if current_platform.is_cuda(): + if current_platform.has_device_capability(100): + + use_cutlass = cutlass_block_fp8_supported and ( + cdiv(weight.shape[0], 128) == weight_scale.shape[0] + and cdiv(weight.shape[1], 128) == weight_scale.shape[1]) + else: + # TODO: update this after switching to public sm90 block scale gemm + # as it also supports weight.shape % 128 != 0 + use_cutlass = cutlass_block_fp8_supported and ( + weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) + else: + use_cutlass = False + + w8a8_blockscale_func = dispatch_w8a8_blockscale_func( + use_cutlass, use_aiter_and_is_supported) + if use_cutlass: + q_input, x_scale = per_token_group_quant_fp8( + input_2d, block_size[1], column_major_scales=use_cutlass) + output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, + block_size, input.dtype) + + else: + q_input, x_scale = per_token_group_quant_fp8( + input_2d, block_size[1], column_major_scales=use_cutlass) + + output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, + block_size, input.dtype) + + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) + + +def apply_w8a8_block_fp8_linear_fake( + input: torch.Tensor, + weight: torch.Tensor, + block_size: List[int], + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, + use_aiter_and_is_supported: bool = False, +) -> torch.Tensor: + output_shape = [*input.shape[:-1], weight.shape[0]] + return torch.empty(output_shape, dtype=input.dtype, device=input.device) + + +if not current_platform.is_cpu(): + direct_register_custom_op( + op_name="apply_w8a8_block_fp8_linear", + op_func=apply_w8a8_block_fp8_linear, + mutates_args=[], + fake_impl=apply_w8a8_block_fp8_linear_fake, + ) + + +def input_to_float8( + x: torch.Tensor, + dtype: Optional[torch.dtype] = None +) -> tuple[torch.Tensor, torch.Tensor]: + """This function quantizes input values to float8 values " + "with tensor-wise quantization.""" + dtype = current_platform.fp8_dtype() if dtype is None else dtype + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() + + +def block_quant_to_tensor_quant( + x_q_block: torch.Tensor, + x_s: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """This function converts block-wise quantization to tensor-wise + quantization. The inputs are block-wise quantization tensor `x_q_block`, + block-wise quantization scale and the block size. + The outputs are tensor-wise quantization tensor and tensor-wise + quantization scale. Note only float8 is supported for now. + """ + x_dq_block = scaled_dequantize(x_q_block, x_s) + x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype) + return x_q_tensor, scale + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + groups_per_row = y_num_columns // group_size + + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + row = g_id // groups_per_row + row_g_id = g_id % groups_per_row + + # Ensure offset calculations use int64 to prevent overflow + y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) * + group_size) + y_ptr += y_ptr_offset + + y_q_ptr_offset = g_id.to(tl.int64) * group_size + y_q_ptr += y_q_ptr_offset + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +@triton.jit +def _per_token_group_quant_fp8_colmajor( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + y_row_stride, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + groups_per_row = y_num_columns // group_size + + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + row = g_id // groups_per_row + row_g_id = g_id % groups_per_row + + # Ensure offset calculations use int64 to prevent overflow + y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) * + group_size) + y_ptr += y_ptr_offset + + y_q_ptr_offset = g_id.to(tl.int64) * group_size + y_q_ptr += y_q_ptr_offset + + # Convert g_id the flattened block coordinate to 2D so we can index + # into the output y_scales matrix + blocks_per_row = y_num_columns // group_size + scale_col = g_id % blocks_per_row + scale_row = g_id // blocks_per_row + # Ensure offset calculation uses int64 for y_s_ptr + y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to( + tl.int64) + y_s_ptr += y_s_ptr_offset + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, + out_q: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Args: + x: The input tensor with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` + is supported for now. + column_major_scales: Outputs scales in column major. + out_q: Optional output tensor. If not provided, function will create. + Returns: + tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. + """ + dtype = current_platform.fp8_dtype() if dtype is None else dtype + assert (x.shape[-1] % group_size == 0), ( + f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}") + assert x.stride(-1) == 1, "`x` groups must be contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + assert out_q is None or out_q.shape == x.shape + x_q = out_q + if x_q is None: + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + + M = x.numel() // group_size + N = group_size + if column_major_scales: + shape = (x.shape[-1] // group_size, ) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, + dtype=torch.float32).permute(-1, -2) + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size, ) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + if column_major_scales: + _per_token_group_quant_fp8_colmajor[(M, )]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x.stride(0), + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _per_token_group_quant_fp8[(M, )]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x.stride(0), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s + + +@triton.jit +def _w8a8_block_fp8_matmul( + # Pointers to inputs and output + A, + B, + C, + As, + Bs, + # Shape for matmul + M, + N, + K, + # Block size for block-wise quantization + group_n, + group_k, + # Stride for inputs and output + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_As_m, + stride_As_k, + stride_Bs_k, + stride_Bs_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Triton-accelerated function used to perform linear operations (dot + product) on input tensors `A` and `B` with block-wise quantization, and + store the result in output tensor `C`. + """ + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + As_ptrs = As + offs_am * stride_As_m + offs_bsn = offs_bn // group_n + Bs_ptrs = Bs + offs_bsn * stride_Bs_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, + mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if C.dtype.element_ty == tl.bfloat16: + c = accumulator.to(tl.bfloat16) + elif C.dtype.element_ty == tl.float16: + c = accumulator.to(tl.float16) + else: + c = accumulator.to(tl.float32) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +@functools.lru_cache +def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, + block_k: int) -> Optional[dict[int, Any]]: + """ + Return optimized configurations for the w8a8 block fp8 kernel. + The return value will be a dictionary that maps an irregular grid of + batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the + kernel on a given batch size bs, the closest batch size in the grid should + be picked and the associated configuration chosen to invoke the kernel. + """ + + # First look up if an optimized configuration is available in the configs + # directory + device_name = current_platform.get_device_name().replace(" ", "_") + json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n},{block_k}].json" # noqa: E501 + + config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + if os.path.exists(config_file_path): + with open(config_file_path) as f: + logger.info( + "Using configuration from %s for W8A8 Block FP8 kernel.", + config_file_path, + ) + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + + # If no optimized configuration is available, we will use the default + # configuration + logger.warning( + "Using default W8A8 Block FP8 kernel config. Performance might " + "be sub-optimal! Config file not found at %s", + config_file_path, + ) + return None + + +def w8a8_block_fp8_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + """This function performs matrix multiplication with block-wise + quantization. + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + Args: + A: The input tensor, e.g., activation. + B: The input tensor, e.g., weight. + As: The per-token-group quantization scale for `A`. + Bs: The per-block quantization scale for `B`. + block_size: The block size for per-block quantization. It should + be 2-dim, e.g., [128, 128]. + output_dytpe: The dtype of the returned tensor. + Returns: + torch.Tensor: The result of matmul. + """ + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] + + assert B.ndim == 2 and Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] + + C_shape = A.shape[:-1] + (N, ) + C = A.new_empty(C_shape, dtype=output_dtype) + + configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1]) + if configs: + # Get the optimal config if there is one + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Default config + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0] + # BLOCK_SIZE_K must be divisible by block_size[1] + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_size[0], + "BLOCK_SIZE_K": block_size[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + } + + def grid(META): + return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * + triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + + _w8a8_block_fp8_matmul[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + **config, + ) + + return C diff --git a/vllm/model_executor/layers/quantization/utils/fused_moe_cuda.py b/vllm/model_executor/layers/quantization/utils/fused_moe_cuda.py new file mode 100644 index 0000000..19e8ad4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/fused_moe_cuda.py @@ -0,0 +1,340 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Fused MoE kernel.""" +import functools +import json +import os +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +import triton +import triton.language as tl + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.logger import init_logger + +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size) +from grouped_gemm_int4 import moe_gemm_w4a16 +from grouped_gemm_int4.ops import permute as permute_topK, unpermute as unpermute_topK +import torch.nn.functional as F +logger = init_logger(__name__) +device_name = current_platform.get_device_name() + +def config_cuda(M): + + k100ai_gemm1_m_to_mode_dict = { + 1: 'M16N16K256NN1NW8B240', + 2: 'M16N16K256NN1NW16B360', + 3: 'M16N16K256NN1NW16B360', + 4: 'M16N16K256NN1NW16B360', + 5: 'M16N16K256NN1NW16B360', + 6: 'M16N16K256NN1NW16B360', + 7: 'M16N16K256NN1NW16B360', + 8: 'M16N16K256NN1NW16B240', + 9: 'M16N16K256NN1NW16B360', + 10: 'M16N16K256NN1NW16B360', + 11: 'M16N16K256NN1NW16B360', + 12: 'M16N16K256NN1NW16B360', + 13: 'M16N16K256NN1NW16B360', + 14: 'M16N16K256NN1NW16B360', + 15: 'M16N16K256NN1NW16B360', + 16: 'M16N16K256NN1NW16B240', + 17: 'M16N16K256NN1NW16B360', + 18: 'M16N16K256NN1NW16B360', + 19: 'M16N16K256NN1NW16B360', + 20: 'M16N16K256NN1NW16B360', + 21: 'M16N16K256NN1NW16B360', + 22: 'M16N16K256NN1NW16B360', + 23: 'M16N16K256NN1NW16B360', + 24: 'M16N16K256NN1NW16B360', + 25: 'M16N16K256NN1NW16B360', + 26: 'M16N16K256NN1NW16B360', + 27: 'M16N16K256NN1NW16B240', + 28: 'M16N16K256NN1NW16B360', + 29: 'M16N16K256NN1NW16B360', + 30: 'M16N16K256NN1NW16B360', + 31: 'M16N16K256NN1NW16B240', + 32: 'M16N16K256NN1NW16B360', + 64: 'M16N16K256NN1NW16B360', + 128: 'M16N16K256NN1NW16B240', + 256: 'M16N16K256NN1NW16B360', + 512: 'M16N16K128NN1NW8B120', + 768: 'M16N32K128NN1NW16B100', + 1024: 'M16N32K128NN1NW16B120', + } + + + k100ai_gemm2_m_to_mode_dict = { + 1: 'M16N32K256NN8NW1B240', + 2: 'M16N32K256NN8NW1B360', + 3: 'M16N32K256NN8NW1B360', + 4: 'M16N32K256NN4NW1B360', + 5: 'M16N32K256NN4NW1B360', + 6: 'M16N32K256NN4NW1B360', + 7: 'M16N32K256NN4NW1B360', + 8: 'M16N32K256NN8NW1B360', + 9: 'M16N32K256NN8NW1B240', + 10: 'M16N32K256NN8NW1B240', + 11: 'M16N32K256NN8NW1B240', + 12: 'M16N32K256NN8NW1B240', + 13: 'M16N32K256NN4NW1B360', + 14: 'M16N32K256NN16NW1B360', + 15: 'M16N32K256NN16NW1B360', + 16: 'M16N32K256NN16NW1B360', + 17: 'M16N32K256NN8NW1B240', + 18: 'M16N32K256NN8NW1B240', + 19: 'M16N32K256NN16NW1B360', + 20: 'M16N32K256NN16NW1B360', + 21: 'M16N32K256NN16NW1B360', + 22: 'M16N32K256NN16NW1B360', + 23: 'M16N32K256NN16NW1B360', + 24: 'M16N32K256NN16NW1B240', + 25: 'M16N32K256NN16NW1B360', + 26: 'M16N32K256NN16NW1B360', + 27: 'M16N32K256NN16NW1B360', + 28: 'M16N32K256NN16NW1B360', + 29: 'M16N64K256NN4NW1B240', + 30: 'M16N32K256NN16NW1B360', + 31: 'M16N32K256NN16NW1B360', + 32: 'M16N32K256NN16NW1B240', + 64: 'M16N32K256NN16NW1B360', + 128: 'M16N64K256NN4NW1B240', + 256: 'M16N32K256NN16NW1B360', + 512: 'M16N64K256NN8NW1B120', + 768: 'M16N64K256NN16NW1B360', + 1024: 'M16N64K256NN16NW1B360', + } + + + bw_gemm1_m_to_mode_dict = { + 1: 'M16N16K256NN1NW8B360', + 2: 'M16N16K256NN1NW4B360', + 3: 'M16N32K256NN1NW8B240', + 4: 'M16N32K256NN1NW4B360', + 5: 'M16N64K256NN1NW4B240', + 6: 'M16N32K256NN1NW8B240', + 7: 'M16N32K256NN1NW8B360', + 8: 'M16N64K256NN1NW4B360', + 9: 'M16N64K256NN1NW4B240', + 10: 'M16N32K256NN1NW8B240', + 11: 'M16N64K256NN1NW4B240', + 12: 'M16N64K256NN1NW4B360', + 13: 'M16N32K256NN1NW8B240', + 14: 'M16N32K256NN1NW8B240', + 15: 'M16N32K256NN1NW8B240', + 16: 'M16N64K256NN1NW4B360', + 17: 'M16N32K256NN1NW8B240', + 18: 'M16N64K256NN1NW4B240', + 19: 'M16N32K256NN1NW8B240', + 20: 'M16N32K256NN1NW8B240', + 21: 'M16N32K256NN1NW8B240', + 22: 'M16N32K256NN1NW8B240', + 23: 'M16N32K256NN1NW8B240', + 24: 'M16N64K256NN1NW4B240', + 25: 'M16N32K256NN1NW8B240', + 26: 'M16N32K256NN1NW8B240', + 27: 'M16N32K256NN1NW8B240', + 28: 'M16N32K256NN1NW8B240', + 29: 'M16N64K256NN1NW4B240', + 30: 'M16N64K256NN1NW4B240', + 31: 'M16N32K256NN1NW8B240', + 32: 'M16N64K256NN1NW4B240', + 64: 'M16N32K256NN1NW8B240', + 128: 'M16N64K256NN1NW4B240', + 256: 'M16N64K256NN1NW4B240', + 512: 'M16N64K256NN1NW4B240', + 768: 'M16N64K256NN1NW4B240', + 1024: 'M16N64K256NN1NW4B240', + } + + + bw_gemm2_m_to_mode_dict = { + 1: 'M16N32K128NN8NW1B240', + 2: 'M16N64K256NN8NW1B240', + 3: 'M16N64K256NN4NW1B360', + 4: 'M16N64K256NN16NW1B240', + 5: 'M16N64K256NN8NW1B240', + 6: 'M16N64K256NN8NW1B240', + 7: 'M16N64K256NN16NW1B240', + 8: 'M16N64K256NN8NW1B240', + 9: 'M16N64K256NN16NW1B360', + 10: 'M16N64K256NN8NW1B240', + 11: 'M16N64K256NN16NW1B360', + 12: 'M16N64K256NN8NW1B240', + 13: 'M16N64K256NN16NW1B240', + 14: 'M16N64K256NN16NW1B360', + 15: 'M16N64K256NN16NW1B240', + 16: 'M16N64K256NN16NW1B240', + 17: 'M16N64K256NN8NW1B240', + 18: 'M16N64K256NN8NW1B240', + 19: 'M16N64K256NN16NW1B240', + 20: 'M16N64K256NN8NW1B240', + 21: 'M16N64K256NN16NW1B240', + 22: 'M16N64K256NN16NW1B360', + 23: 'M16N64K256NN16NW1B360', + 24: 'M16N64K256NN16NW1B240', + 25: 'M16N64K256NN8NW1B240', + 26: 'M16N64K256NN16NW1B240', + 27: 'M16N64K256NN16NW1B240', + 28: 'M16N64K256NN16NW1B240', + 29: 'M16N64K256NN16NW1B240', + 30: 'M16N64K256NN16NW1B240', + 31: 'M16N64K256NN8NW1B240', + 32: 'M16N64K256NN16NW1B240', + 64: 'M16N64K256NN16NW1B240', + 128: 'M16N64K256NN16NW1B240', + 256: 'M16N64K256NN16NW1B240', + 512: 'M16N64K256NN16NW1B240', + 768: 'M16N64K256NN16NW1B240', + 1024: 'M16N64K256NN16NW1B240', + } + + reference_points = [32, 64, 128, 256, 512, 1024] + NearestM = -1 + + if M <= 32: + NearestM = M + else: + NearestM = min(reference_points, key=lambda x: abs(x - M)) + + if device_name == "K100_AI": + mode_1 = k100ai_gemm1_m_to_mode_dict.get(NearestM, k100ai_gemm1_m_to_mode_dict[32]) + mode_2 = k100ai_gemm2_m_to_mode_dict.get(NearestM, k100ai_gemm2_m_to_mode_dict[32]) + else: + mode_1 = bw_gemm1_m_to_mode_dict.get(NearestM, bw_gemm1_m_to_mode_dict[32]) + mode_2 = bw_gemm2_m_to_mode_dict.get(NearestM, bw_gemm2_m_to_mode_dict[32]) + + return mode_1, mode_2 + +def fused_experts_cuda(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, + expert_map: Optional[torch.Tensor] = None,): + if inplace: + fused_experts_impl_cuda(hidden_states, w1, w2, topk_weights, topk_ids, True, + use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale, + w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape, + expert_map) + + return hidden_states + else: + return fused_experts_impl_cuda(hidden_states, w1, w2, topk_weights, topk_ids, False, + use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale, + w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape, + expert_map) + + +def fused_experts_impl_cuda(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, + expert_map: Optional[torch.Tensor] = None,): + # Check constraints. + + assert hidden_states.shape[1] // 2 == w1.shape[ + 2], "Hidden size mismatch" + + + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + + num_tokens, _ = hidden_states.shape + + E, N, _ = w1.shape + # We execute the fused_moe kernel in chunks to circumvent this issue: + # https://github.com/vllm-project/vllm/issues/5938 + + M = num_tokens + topk = topk_ids.shape[1] + mode_1, mode_2 = config_cuda(M) + # config = get_config_func(M) + + intermediate_cache1 = torch.empty((M, topk, N), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache2 = torch.empty((M * topk, N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache3 = torch.empty((M, topk, w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype) + + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") + + if inplace: + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(topk_ids, 16, E, expert_map, hidden_states.shape[0])) + moe_gemm_w4a16.gemm1_w4a16(sorted_token_ids, # sorted_token_ids.to(torch.uint16) + hidden_states, # hidden_states + w1, # w1 + intermediate_cache1, # gemm1_out + num_tokens_post_padded, # 实际专家数 + expert_ids, # expert_id_vec + w1_scale, # scale_zero + block_shape[1], # group_size + topk=topk, # topk + mode=mode_1) # mode=gemm1_mode + + torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + # return intermediate_cache2 + moe_gemm_w4a16.gemm2_w4a16(sorted_token_ids, # sorted_token_ids.to(torch.uint16) + intermediate_cache2, # hidden_states + w2, # w2 + intermediate_cache3, # gemm2_out + num_tokens_post_padded, + expert_ids, # expert_id_vec + w2_scale, # scale_zero + topk_weights, # topk_weights + block_shape[1], # group_size + topk=topk, # topk + mode=mode_2) # mode=gemm2_mode + ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states) + + + return out_hidden_states + + diff --git a/vllm/model_executor/layers/quantization/utils/gptq_utils.py b/vllm/model_executor/layers/quantization/utils/gptq_utils.py new file mode 100644 index 0000000..db82b0d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/gptq_utils.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from copy import deepcopy +from typing import Optional, Union + +import regex as re +import torch + +from vllm.config import QuantizationConfig +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, UnquantizedEmbeddingMethod) + + +# Match dynamic rules with module name (prefix) and override quantize +# config if module (prefix) matches a rule +def override_config(config: QuantizationConfig, prefix: str): + weight_bits = get_dynamic_override(config, prefix, "bits", + config.weight_bits) + if isinstance(weight_bits, int): + config.weight_bits = weight_bits + group_size = get_dynamic_override(config, prefix, "group_size", + config.group_size) + if isinstance(group_size, int): + config.group_size = group_size + desc_act = get_dynamic_override(config, prefix, "desc_act", + config.desc_act) + if isinstance(desc_act, bool): + config.desc_act = desc_act + + config.pack_factor = 32 // config.weight_bits # packed into int32 + if config.get_name() == "gptq_marlin": + is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym) + if isinstance(is_sym, bool): + config.is_sym = is_sym + + if (config.weight_bits, config.is_sym) not in config.TYPE_MAP: + raise ValueError("Unsupported quantization config: " + f"bits={config.weight_bits}, sym={config.is_sym}") + + config.quant_type = config.TYPE_MAP[(config.weight_bits, + config.is_sym)] + elif config.get_name() == "gptq": + if config.weight_bits not in [2, 3, 4, 8]: + raise ValueError( + "Currently, only 2/3/4/8-bit weight quantization is " + f"supported for GPTQ, but got {config.weight_bits} bits.") + + +def get_dynamic_override( + config: QuantizationConfig, + layer_name: str, + key: Optional[str] = None, + default_value: Union[int, bool, + None] = None) -> Union[dict, int, bool, None]: + for pattern, pattern_dict in config.dynamic.items(): + # Negative match: matched modules are excluded from quantized init + if pattern.startswith("-:"): + if re.match(pattern.removeprefix("-:"), layer_name): + return False + # Positive match: matched modules have quant properties overrides + # base quant config + elif re.match(pattern.removeprefix("+:"), layer_name): + if key is None: + return pattern_dict + else: + return pattern_dict.get(key, default_value) + return default_value + + +def get_linear_quant_method( + config: QuantizationConfig, + layer: torch.nn.Module, + prefix: str, + linear_method_cls: type, +): + cloned_config = deepcopy(config) + parallel_lm_head_quantized = isinstance( + layer, ParallelLMHead) and cloned_config.lm_head_quantized + if isinstance(layer, LinearBase) or parallel_lm_head_quantized: + # False = skip module, None = no override, else = Positive match + if get_dynamic_override( # noqa: E712 + cloned_config, # noqa: E712 + layer_name=prefix) == False: # noqa: E712 + if parallel_lm_head_quantized: + return UnquantizedEmbeddingMethod() + return UnquantizedLinearMethod() + + if prefix: + # Dynamic per module/layer rules may override base config + override_config(cloned_config, prefix=prefix) + + return linear_method_cls(cloned_config) + return None diff --git a/vllm/model_executor/layers/quantization/utils/layer_utils.py b/vllm/model_executor/layers/quantization/utils/layer_utils.py new file mode 100644 index 0000000..fbc0f23 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/layer_utils.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Union + +import torch + + +def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor): + assert dst.dtype == src.dtype, "Tensors must have the same dtype" + + # update tensor shape and stride + dst.as_strided_(src.shape, src.stride()) + + # If not the same underlying storage move tensor data + if dst.data_ptr() != src.data_ptr(): + dst.copy_(src) + del src + + +# Newly generated tensors need to replace existing tensors that are +# already registered as parameters by vLLM (and won't be freed) +def replace_parameter(mod: torch.nn.Module, name: str, + new: Union[torch.Tensor, torch.nn.Parameter]) -> None: + + old = getattr(mod, name) + if type(old) is type(new) and old.dtype == new.dtype and \ + old.untyped_storage().nbytes() == new.untyped_storage().nbytes(): + # If we can just update in-place to avoid re-registering + # can be faster if the underlying storage is the same + update_tensor_inplace(old, new) + else: + # Fallback re-register parameter, convert to Parameter if necessary + # this not only ensures we don't register a tensor as a parameter, but + # also ensures that all parameter subclasses get re-registered as + # parameters for `torch.compile` compatibility + if not isinstance(new, torch.nn.Parameter): + new = torch.nn.Parameter(new, requires_grad=False) + mod.register_parameter(name, + torch.nn.Parameter(new, requires_grad=False)) diff --git a/vllm/model_executor/layers/quantization/utils/machete_utils.py b/vllm/model_executor/layers/quantization/utils/machete_utils.py new file mode 100644 index 0000000..fbb850d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/machete_utils.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm.scalar_type import ScalarType, scalar_types + +MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128] + + +def query_machete_supported_quant_types(zero_points: bool) -> list[ScalarType]: + if zero_points: + return [scalar_types.uint4, scalar_types.uint8] + else: + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def query_machete_supported_act_types(zero_points: bool) -> list[ScalarType]: + return [torch.float16, torch.bfloat16] + + +def query_machete_supported_group_sizes(act_type: torch.dtype) -> list[int]: + """ + Queries the supported group sizes for Machete based on the activation type. + + Args: + act_type: The activation data type (torch.float16, torch.bfloat16). + + Returns: + A list of supported group sizes. The group size must + be divisible by `TileShapeK = 128 * 8 // num_bits(act_type)`. + -1 indicates per-channel quantization. + """ + if act_type in [torch.float16, torch.bfloat16]: + return [-1, 64, 128] + else: + return [-1, 128] + + +def check_machete_supports_shape(in_features: int, out_featrues: int) \ + -> tuple[bool, Optional[str]]: + if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0: + return False, "Input features size must be divisible by "\ + f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}" + if out_featrues % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0: + return False, "Output features size must be divisible by "\ + f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}" + return True, None diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py new file mode 100644 index 0000000..de2618e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -0,0 +1,578 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import numpy +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +from .quant_utils import pack_cols, unpack_cols + +logger = init_logger(__name__) + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# In case there is a performance issue with Marlin, the variable below can be +# changed to False, which allows Marlin to perform global reductions in fp16 +# precision (instead of fp32), and therefore, save on some memory movements. +USE_FP32_REDUCE_DEFAULT = True + + +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. +# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types( + has_zp: Optional[bool] = None, + include_fp_type: bool = True, + device_capability: Optional[int] = None, +): + if device_capability is None: + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + + if device_capability < 80: + return [] + + # - has_zp is True: return quant_types that has zero points + # - has_zp is False: return quant_types that has not zero points + # - has_zp is None: both + if has_zp is None: + types0 = query_marlin_supported_quant_types(False, include_fp_type, + device_capability) + types1 = query_marlin_supported_quant_types(True, include_fp_type, + device_capability) + return types0 + types1 + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4] + else: + # GPTQ style, unsigned + symmetric bias + res = [scalar_types.uint4b8, scalar_types.uint8b128] + if include_fp_type: + res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f] + return res + + +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]: + + if device_capability is None: + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + + supported_types = query_marlin_supported_quant_types( + has_zp, True, device_capability) + + if quant_type not in supported_types: + return (False, f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): + return (False, f"Marlin does not support group_size = {group_size}. " + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") + + return True, None + + +def check_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, + device_capability) + return cond + + +def verify_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise ValueError(f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + # Validate input_size_per_partition + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + +def check_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) + except ValueError as e: + return False, e.__str__() + return True, None + +#暂不支持marlinlinear +def check_marlin_supports_layer(layer: LinearBase, group_size: int) \ + -> bool: + output_size_per_partition = getattr(layer, "output_size_per_partition", + None) or layer.output_size + input_size_per_partition = getattr(layer, "input_size_per_partition", + None) or layer.input_size + + # return check_marlin_supports_shape( + # output_size_per_partition=output_size_per_partition, + # input_size_per_partition=input_size_per_partition, + # input_size=layer.input_size, + # group_size=group_size)[0] + return False + +def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \ + -> bool: + hidden_size = layer.hidden_size + intermediate_size_per_partition = layer.intermediate_size_per_partition + # apply_router_weight_on_input is not supported for moe marlin + supports_router_weight = not layer.apply_router_weight_on_input + # moe marlin requires the activation to be silu + supports_activation = layer.activation == "silu" + #暂时只支持bw + device_name = torch.cuda.get_device_properties(torch.cuda.current_device()).name + supports_device = "BW" in device_name + # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size) + # down: (n, k) = (hidden_size, intermediate_size_per_partition) + # moe marlin requires n % 128 == 0 and k % 64 == 0 + supports_shape = hidden_size % 128 == 0 and \ + intermediate_size_per_partition % max(64, group_size) == 0 + + #暂时只支持64 + supports_group_size = group_size in [64] + return supports_shape and supports_group_size and \ + supports_router_weight and supports_activation and supports_device + + +def marlin_make_workspace(output_size_per_partition: int, + device: torch.device) -> torch.Tensor: + max_workspace_size = (output_size_per_partition // + GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL + + return torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) + + +def marlin_make_workspace_new(device: torch.device, + max_blocks_per_sm: int = 1) -> torch.Tensor: + # In the new marlin kernel, we use the num of threadblocks as workspace + # size. The num of threadblocks is is sms_count * max_blocks_per_sm. + sms = torch.cuda.get_device_properties(device).multi_processor_count + return torch.zeros(sms * max_blocks_per_sm, + dtype=torch.int, + device=device, + requires_grad=False) + + +def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int, + is_row_parallel: bool) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + +def marlin_sort_g_idx( + g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +# def get_scale_perms(): +# scale_perm: list[int] = [] +# for i in range(8): +# scale_perm.extend([i + 8 * j for j in range(8)]) +# scale_perm_single: list[int] = [] +# for i in range(4): +# scale_perm_single.extend( +# [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) +# return scale_perm, scale_perm_single + + +# def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, +# group_size: int) -> torch.Tensor: + +# scale_perm, scale_perm_single = get_scale_perms() +# if group_size < size_k and group_size != -1: +# s = s.reshape((-1, len(scale_perm)))[:, scale_perm] +# else: +# s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] +# s = s.reshape((-1, size_n)).contiguous() + +# return s +def get_scale_perms(): + scale_perm: List[int] = [] + for i in range(16): # 遍历列方向不同scale的 8个线程 + scale_perm.extend([i + 16 * j for j in range(8)]) # 插入 8 个数据块中 对应位置的索引 + return scale_perm + + +def marlin_permute_scales(s: torch.Tensor, # [56, 512] # torch.float16 + size_k: int, # 7168 + size_n: int, # 512 + group_size: int # 128 + ) -> torch.Tensor: + # 将[128, 128](fp16) B矩阵中 每个[16, 16]计算块中的对应位置的 zero值 放到一起 + scale_perm = get_scale_perms() + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + s = s.reshape((-1, size_n)).contiguous() + return s + +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + +def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + # 和 scale 使用一致的重排逻辑,将[128, 128](fp16) B矩阵中 每个[16, 16]计算块中的对应位置的 zero值 放到一起 + scale_perm = get_scale_perms() + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # uint4 混排 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + # uint4打包成 int32 + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp + + +def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, + num_bits) + return output + + +def maybe_warn_marlin_atomic_add(device, dtype): + if torch.compiler.is_dynamo_compiling(): + return + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + logger.info_once( + "You are running Marlin kernel with bf16 on GPUs before SM90. " + "You can consider change to fp16 to achieve better performance " + "if possible.") + + +def maybe_warn_marlin_atomic_add_env(): + if torch.compiler.is_dynamo_compiling(): + return + if envs.VLLM_MARLIN_USE_ATOMIC_ADD: + return + logger.info_once( + "Marlin kernel can achieve better performance for small size_n " + "with experimental use_atomic_add feature. " + "You can consider set environment variable " + "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.") + + +def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device, + dtype: torch.dtype) -> bool: + + # the performance of atomicAdd is better than global reduce + # only when m*n is small and k is large + if n >= 2048 or k < 2048 or device.type != "cuda": + return False + + # disable atomicAdd reduce by default, + # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1 + if not envs.VLLM_MARLIN_USE_ATOMIC_ADD: + maybe_warn_marlin_atomic_add_env() + return False + + # sm8x doesn't support atomicAdd + bfloat16 natively + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + maybe_warn_marlin_atomic_add(device, dtype) + return False + + return True + + +def apply_gptq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def apply_awq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + +def merge_scales_zeros(marlin_s: torch.Tensor, marlin_zp: torch.Tensor, + data_num_0: int, data_num_1: int) -> torch.Tensor: + """ + 合并两个 Tensor, 每行交替取 data_num_0 个 float16 和 data_num_1 个 int32。 + 要求: + - marlin_s 每行长度能被 data_num_0 整除 + - marlin_zp 每行长度能被 data_num_1 整除 + - 合并后的总字节数必为 4 的倍数 + + 返回: + [N, M] 的 int32 Tensor(行数一致,列数已对齐) + """ + assert marlin_s.shape[0] == marlin_zp.shape[0], "Batch size mismatch" + assert marlin_s.dtype == torch.float16 + assert marlin_zp.dtype == torch.int32 + + N, D0 = marlin_s.shape + _, D1 = marlin_zp.shape + + assert D0 % data_num_0 == 0, "marlin_s 每行必须能被 data_num_0 整除" + assert D1 % data_num_1 == 0, "marlin_zp 每行必须能被 data_num_1 整除" + + s_block_count = D0 // data_num_0 + zp_block_count = D1 // data_num_1 + assert s_block_count == zp_block_count + + total_blocks = s_block_count + + # 转为字节视图 + s_bytes = marlin_s.view(torch.uint8).reshape(N, -1) + zp_bytes = marlin_zp.view(torch.uint8).reshape(N, -1) + + # 每行的合并结果 + merged_rows = [] + + for i in range(N): + s_row = s_bytes[i] + zp_row = zp_bytes[i] + s_ptr = 0 + zp_ptr = 0 + merged = [] + + for _ in range(total_blocks): + # 如果 s 还有剩余 block,就取 + if s_ptr < s_row.numel(): + chunk_s = s_row[s_ptr: s_ptr + data_num_0 * 2] # float16 = 2 字节 + merged.append(chunk_s) + s_ptr += data_num_0 * 2 + + # 如果 zp 还有剩余 block,就取 + if zp_ptr < zp_row.numel(): + chunk_zp = zp_row[zp_ptr: zp_ptr + data_num_1 * 4] # int32 = 4 字节 + merged.append(chunk_zp) + zp_ptr += data_num_1 * 4 + + # 合并所有字节,并直接转换为 int32 + merged_bytes = torch.cat(merged) + # assert merged_bytes.numel() % 4 == 0, "最终字节长度必须是4的倍数" + merged_int32 = merged_bytes.view(torch.int32) + merged_rows.append(merged_int32) + + # 所有合并行长度一致,可以直接堆叠 + result = torch.stack(merged_rows) + return result + +def awq_marlin_moe_permute_sz( + s : torch.Tensor, + z : torch.Tensor, + size_k: int, + size_n: int, + ) -> torch.Tensor: + num_experts = s.shape[0] + + # output = torch.empty((num_experts, size_k // 16, size_n//2 + size_n//8), + # device=z.device, + # dtype=z.dtype) + + outputs = [] + for e in range(num_experts): + out_sz = merge_scales_zeros(s[e], z[e], 128, 16) + outputs.append(out_sz) + return torch.stack(outputs, dim=0) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py new file mode 100644 index 0000000..ca10db6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -0,0 +1,283 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +import vllm._custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, + should_use_atomic_add_reduce) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16] + +logger = init_logger(__name__) + + +def is_fp4_marlin_supported(): + return current_platform.has_device_capability(80) + + +def fp4_marlin_process_scales(marlin_scales): + if not (marlin_scales >= 0).all(): + logger.warning_once( + "NVFP4 Marlin assumes the scales to be >=0, but has encountered " + "negative scales. Accuracy will likely be degraded. This is " + "because it changes the scales from FP8-S1E4M3 to a special " + "FP8-S0E5M3 format to speedup the dequantization.") + + # convert to half first, we would convert to fp8 later + marlin_scales = marlin_scales.to(torch.half) + + # 8 is the number of scale number using by one thread + marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) + marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( + marlin_scales.size(0) * 2, -1) + + # fit the layout of fp8 dequantization + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1) + + # We assume that weight_scale (FP8-S1E4M3) is always greater + # than or equal to 0. So we can convert + # (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format. + # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1 + # when weight_scale > 0. This allows us to have an exponent bias + # closer to zero after dequantization. + + marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1 + marlin_scales = marlin_scales.view(torch.float8_e4m3fn) + marlin_scales = marlin_scales[:, 1::2].contiguous() + + return marlin_scales + + +def fp4_marlin_process_global_scale(global_scale): + assert global_scale.dtype in [torch.half, torch.bfloat16] + fp4_exponent = 2 + if global_scale.dtype == torch.half: + target_exponent = 5 + elif global_scale.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126 + exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1) + return global_scale * (2.0**(exponent_bias - 7)) + + +def apply_fp4_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + # For GPUs that lack FP4 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP4 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=weight_scale_2, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float4_e2m1f, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + + assert layer.weight.shape == (part_size_n, part_size_k // 2) + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace_new(device) + + # WEIGHT + # Repack weights to marlin format + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = layer.weight.view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Permute scales + weight_scale = layer.weight_scale.T.to(param_dtype) + weight_scale = marlin_permute_scales(s=weight_scale, + size_k=part_size_k, + size_n=part_size_n, + group_size=16) + weight_scale = fp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + weight_scale_2 = layer.weight_scale_2.to(param_dtype) + weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2) + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, + requires_grad=False) + + return + + +def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + + # WORKSPACE + device = layer.w13_weight.device + param_dtype = layer.params_dtype + layer.workspace = marlin_make_workspace_new(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + # WEIGHT + # Repack weights to marlin format + for name in ["w13_weight", "w2_weight"]: + weight = getattr(layer, name) + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + assert weight.shape == (e, size_n, size_k // 2) + + for i in range(e): + qweight = weight[i].view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4) + tensor_list.append(marlin_qweight) + + weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + weight = torch.nn.Parameter(weight, requires_grad=False) + + setattr(layer, name, weight) + + # WEIGHT SCALES + # Permute scales + for name in ["w13", "w2"]: + scales = getattr(layer, name + "_weight_scale").to(param_dtype) + global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) + + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + for i in range(e): + marlin_scales = marlin_permute_scales(s=scales[i].T, + size_k=size_k, + size_n=size_n, + group_size=16) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + tensor_list.append(marlin_scales) + + scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + scales = torch.nn.Parameter(scales, requires_grad=False) + setattr(layer, name + "_weight_scale", scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + global_scale = torch.nn.Parameter(global_scale, requires_grad=False) + setattr(layer, name + "_weight_scale_2", global_scale) + + +def rand_marlin_weight_fp4_like(weight, group_size): + assert group_size > 0 + size_n, size_k = weight.shape + device = weight.device + + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6 + global_scale = scales.max() / 448 + scales = (scales / global_scale).to(torch.float8_e4m3fn) + + fp4_weight = torch.randint(0, + 256, (size_n, size_k // 2), + dtype=torch.uint8, + device=weight.device) + fp4_weight_part_1 = ((fp4_weight & 0b10000000) | + ((fp4_weight & 0b01110000) >> 2)) + fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) + fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) + + fp4_weight2 = fp4_weight << 4 + fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | + ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) + fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) + + weight_ref = torch.cat( + [fp4_weight_part_2.unsqueeze(2), + fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k) + weight_ref = weight_ref * global_scale.to(weight.dtype) * \ + scales.repeat_interleave(group_size, 1).to(weight.dtype) + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + + marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), + size_k=size_k, + size_n=size_n, + group_size=group_size) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + + return weight_ref.T, marlin_qweight, marlin_scales, global_scale diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py new file mode 100644 index 0000000..5372c49 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -0,0 +1,325 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +import vllm._custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, + should_use_atomic_add_reduce) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + + +def is_fp8_marlin_supported(): + return current_platform.has_device_capability(80) + + +def fp8_fused_exponent_bias_into_scales(scales): + fp8_exponent = 4 + if scales.dtype == torch.half: + target_exponent = 5 + elif scales.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120 + exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1) + s = torch.ones_like(scales) * 2 + s = s**exponent_bias + return scales * s + + +def apply_fp8_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=None, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float8_e4m3fn, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, + size_k_first: bool = True) -> None: + logger.warning_once( + "Your GPU does not have native support for FP8 computation but " + "FP8 quantization is being used. Weight-only FP8 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + weight_block_size = getattr(layer, "weight_block_size", None) + + if size_k_first: + assert layer.weight.shape == (part_size_k, part_size_n) + else: + assert layer.weight.shape == (part_size_n, part_size_k) + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace_new(device) + + # WEIGHT + # Repack weights to marlin format + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = pack_fp8_to_int32(layer.weight, size_k_first) + if not size_k_first: + qweight = qweight.T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=8) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Permute scales + if "weight_scale" in dir(layer): + scales = layer.weight_scale.to(layer.orig_dtype) + elif "weight_scale_inv" in dir(layer): + scales = layer.weight_scale_inv.to(layer.orig_dtype) + del layer.weight_scale_inv + + group_size = -1 if weight_block_size is None else weight_block_size[1] + + # marlin kernel only support channel-wise and group-wise quantization + # we need to convert the scales + if weight_block_size is None: + if scales.nelement() == 1: + # tensor-wise quantization -> channel-wise quantization + # (1, 1) =>(repeat)=> (1, size_n) + scales = scales.view(1, 1).repeat_interleave(part_size_n, 1) + elif scales.nelement() > 1 and scales.nelement() != part_size_n: + assert part_size_n % scales.nelement() == 0 + s_size = scales.nelement() + # tensor-wise quantization (for gate-up proj) + # -> channel-wise quantization + # (1, s_size) =>(repeat)=> (1, size_n) + scales = scales.view(1, s_size) + scales = scales.repeat_interleave(part_size_n // s_size, 1) + else: + # channel-wise quantization + # (1, size_n) + scales = scales.view(1, part_size_n) + else: + # block-wise quantization -> group-wise quantization + # (size_k // block_size[1], ceil(size_n / block_size[0])) + # =>(repeat)=> (size_k // block_size[1], size_n) + if not size_k_first: + scales = scales.T.contiguous() + block_n = weight_block_size[0] + scales = scales.repeat_interleave(block_n, 1) + # size_n may not divisible by block_size[0] + scales = scales[:, :part_size_n] + + marlin_scales = marlin_permute_scales(s=scales, + size_k=part_size_k, + size_n=part_size_n, + group_size=group_size) + marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) + layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + + +def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, + size_k_first: bool = True) -> None: + logger.warning_once( + "Your GPU does not have native support for FP8 computation but " + "FP8 quantization is being used. Weight-only FP8 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + weight_block_size = getattr(layer, "weight_block_size", None) + + # WORKSPACE + device = layer.w13_weight.device + layer.workspace = marlin_make_workspace_new(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + # WEIGHT + # Repack weights to marlin format + for name in ["w13_weight", "w2_weight"]: + weight = getattr(layer, name) + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + if size_k_first: + assert weight.shape == (e, size_k, size_n) + else: + assert weight.shape == (e, size_n, size_k) + + for i in range(e): + qweight = pack_fp8_to_int32(weight[i], size_k_first) + if not size_k_first: + qweight = qweight.T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=8) + tensor_list.append(marlin_qweight) + + weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + weight = torch.nn.Parameter(weight, requires_grad=False) + + setattr(layer, name, weight) + + # WEIGHT SCALES + # Permute scales + group_size = -1 if weight_block_size is None else weight_block_size[1] + + for name in ["w13", "w2"]: + if name + "_weight_scale" in dir(layer): + new_name = name + "_weight_scale" + scales = getattr(layer, new_name).to(layer.orig_dtype) + delattr(layer, new_name) + elif name + "_weight_scale_inv" in dir(layer): + new_name = name + "_weight_scale_inv" + scales = getattr(layer, new_name).to(layer.orig_dtype) + delattr(layer, new_name) + + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + # marlin kernel only support channel-wise and group-wise quantization + # we need to convert the scales + if weight_block_size is None: + if scales.nelement() == e: + # tensor-wise quantization -> channel-wise quantization + # (e, 1, 1) =>(repeat)=> (e, 1, size_n) + scales = scales.view(e, 1, 1).repeat_interleave(size_n, 2) + elif scales.nelement() > e and scales.nelement() != e * size_n: + assert (e * size_n) % scales.nelement() == 0 + s_size = scales.nelement() // e + # tensor-wise quantization (for gate-up proj) + # -> channel-wise quantization + # (e, 1, s_size) =>(repeat)=> (e, 1, size_n) + scales = scales.view(e, 1, s_size) + scales = scales.repeat_interleave(size_n // s_size, 2) + else: + # channel-wise quantization + # (e, 1, size_n) + scales = scales.view(e, 1, size_n) + else: + # block-wise quantization -> group-wise quantization + # (e, size_k // block_size[1], ceil(size_n / block_size[0])) + # =>(repeat)=> (e, size_k // block_size[1], size_n) + if not size_k_first: + scales = scales.permute(0, 2, 1) + block_n = weight_block_size[0] + scales = scales.repeat_interleave(block_n, 2) + # size_n may not divisible by block_size[0] + scales = scales[..., :size_n].contiguous() + + for i in range(e): + marlin_scales = marlin_permute_scales(s=scales[i], + size_k=size_k, + size_n=size_n, + group_size=group_size) + tensor_list.append(marlin_scales) + + scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + scales = fp8_fused_exponent_bias_into_scales(scales) + scales = torch.nn.Parameter(scales, requires_grad=False) + + setattr(layer, name + "_weight_scale", scales) + + +def pack_fp8_to_int32(fp8_tensor: torch.Tensor, + size_k_first: bool = True) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements) + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + assert fp8_tensor.ndim == 2 + + fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor + fp8_tensor = fp8_tensor.contiguous() + # fp8_tensor is contiguous and have shape (N, K) now + # with `.view(torch.int32)`, it become (N, K // 4) + int32_tensor = fp8_tensor.view(torch.int32) + return int32_tensor.T.contiguous() if size_k_first else int32_tensor + + +def marlin_quant_fp8_torch(weight, group_size): + size_n, size_k = weight.shape + device = weight.device + + if group_size != -1: + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(group_size, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + else: + scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(size_k, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + + packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous() + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=packed_weight, + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=8, + ) + + marlin_scales = marlin_permute_scales(s=scales.T, + size_k=size_k, + size_n=size_n, + group_size=group_size) + + marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) + + return weight_ref.T, marlin_qweight, marlin_scales diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py new file mode 100644 index 0000000..b2c228c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -0,0 +1,165 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Utility functions used for tests and benchmarks""" + +from typing import Optional + +import numpy as np +import torch + +from vllm.scalar_type import ScalarType + +from .marlin_utils import (GPTQ_MARLIN_TILE, marlin_permute_scales, + marlin_zero_points) +from .quant_utils import (get_pack_factor, gptq_quantize_weights, + quantize_weights, sort_weights) + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel): + assert (out_features % min_thread_n == 0), ( + "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n)) + + max_workspace_size = ((out_features // min_thread_n) * max_parallel) + + self.scratch = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda") + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: list[int] = [] + for i in range(32): + perm1: list[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_quantize(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None): + size_k, size_n = w.shape + num_bits = quant_type.size_bits + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order, test_perm) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, + group_size: int): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Detect num groups + assert size_k % group_size == 0 + num_groups = size_k // group_size + + # Quantize with zp + w_ref, q_w, s, zp = quantize_weights(w, + quant_type, + group_size, + zero_points=True) + + # Reformat to marlin + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, + weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, + quant_type.size_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py new file mode 100644 index 0000000..1c93c36 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py @@ -0,0 +1,464 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Utility functions used for tests and benchmarks""" + +import random + +import numpy +import torch + +from vllm.scalar_type import ScalarType + +from .marlin_utils_test import marlin_weights +from .quant_utils import gptq_quantize_weights + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, + device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. +def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError( + "Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16") + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32") + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. + + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, + idxs0.unsqueeze(-1) // 2).view( + m, + k // 2) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view( + (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = (meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12)) + elif quadbits_per_meta_elem == 8: + meta = (meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28)) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols, )) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix") + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device) + meta = torch.gather(meta_reordered.view(-1), 0, + meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view( + -1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_(0, dense_offsets, + sparse.view(torch.half).view(-1)) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. + + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " + f"{M} groups") + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask + + +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j:j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): + assert q_24.shape == (size_k, size_n) + + # Remove bias to normalize over 0 + q_24_no_zp = q_24 - wtype.bias + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass( + q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore bias + q_24_comp = q_24_no_zp_comp + wtype.bias + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def get_scale_perms_24(): + scale_perm: list[int] = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: list[int] = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return scale_perm, scale_perm_single + + +def get_weight_perm_24(num_bits: int): + perm_list: list[int] = [] + for i in range(32): + perm1: list[int] = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int, + group_size: int) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms_24() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_24_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( + w_24, quant_type, group_size, act_order=False) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, + quant_type) + size_k_comp = size_k // 2 + + # Reformat to marlin + weight_perm = get_weight_perm_24(quant_type.size_bits) + marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n, + quant_type.size_bits, weight_perm) + marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py new file mode 100644 index 0000000..8a64beb --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import numpy +import torch + +from .marlin_utils_test import marlin_permute_weights +from .quant_utils import get_pack_factor, qqq_quantize_weights + + +def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + if group_size == size_k: + for i in range(pack_factor): + q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i + else: + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def get_qqq_scale_perms(): + scale_perm: list[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: list[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501 +def get_qqq_weight_perm(num_bits: int, quant_type: str): + perm_list: list[int] = [] + for i in range(32): + perm1: list[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + assert quant_type in ["per-channel", + "per-group"], "not supported quantization type" + if num_bits == 4: + if quant_type == "per-channel": + interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) + else: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + else: + raise Exception("num_bits must be 4, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): + scale_perm, scale_perm_single = get_qqq_scale_perms() + if group_size < size_k and group_size != -1: + s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_group = s_group.reshape((-1, size_n)).contiguous() + else: + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_channel = s_channel.reshape((-1, size_n)).contiguous() + + return s_group, s_channel + + +def marlin_qqq_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + quant_type = "per-channel" if group_size == size_k else "per-group" + + # Quantize + w_ref, q_w, s_group, s_channel = qqq_quantize_weights( + w, num_bits, group_size) + + # Reformat to marlin_qqq + weight_perm = get_qqq_weight_perm(num_bits, quant_type) + marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, + weight_perm, group_size) + marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( + s_group, s_channel, size_k, size_n, group_size) + + # Create result + res_list = [ + w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel + ] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py new file mode 100644 index 0000000..9d4a188 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +OCP_MX_BLOCK_SIZE = 32 + + +def per_token_group_quant_mxfp4(x: torch.Tensor, + block_k: int, + scale_calculation_mode: str = "even" + ) -> tuple[torch.Tensor, torch.Tensor]: + try: + from quark.torch.kernel.hw_emulation.hw_emulation_interface import ( + fake_quantize_fp4_fp6_per_group_with_scale) + from quark.torch.quantization.utils import (even_round, + reshape_to_blocks) + except ImportError as err: + raise ImportError("The package `amd-quark` is required to use " + "MX-FP4 models. Please install it with `pip install " + "amd-quark`.") from err + + axis = -1 + block_x = reshape_to_blocks(x, block_k, axis) + amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True) + amax = amax.squeeze(-1) + + # TODO: there are other rounding strategies supported in quark and in the + # config.json that we do not check for here! + if scale_calculation_mode != "even": + raise NotImplementedError( + f"Scale calculation mode {scale_calculation_mode} is not yet " + "supported in MX-FP4 quantization") + scale = even_round(amax, "fp4") + + # Apply dequantize(quantize(x)). + x = fake_quantize_fp4_fp6_per_group_with_scale( + x, + scale.to(x.device), + axis=axis, + group_size=block_k, + quant_dtype="fp4", + ) + + return x, scale diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py new file mode 100644 index 0000000..fb3287d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm._custom_ops import cutlass_scaled_mm_supports_fp4 +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +__all__ = [ + "break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant", + "cutlass_fp4_supported" +] + +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() + +kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.], + dtype=torch.float32) + + +def cutlass_fp4_supported() -> bool: + if not current_platform.is_cuda(): + return False + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + return cutlass_scaled_mm_supports_fp4(capability) + + +def break_fp4_bytes(a, dtype): + assert a.dtype == torch.uint8 + m, n = a.shape + # Vectorized nibble processing + a_flat = a.flatten() + high = (a_flat & 0xF0) >> 4 # Upper nibbles + low = a_flat & 0x0F # Lower nibbles + # Combine nibbles for batch processing + combined = torch.stack((low, high), dim=1).flatten() + # Vectorized sign and magnitude extraction + signs = (combined & 0x08).to(torch.bool) # Sign bits + abs_vals = (combined & 0x07).to(torch.long) + # Device-aware lookup and sign application + kE2M1 = kE2M1ToFloat.to(device=a.device) + values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) + # Reshape to final form + return values.reshape(m, n * 2).to(dtype=dtype) + + +def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): + m_tiles = (m + 128 - 1) // 128 + f = block_size * 4 + k_tiles = (k + f - 1) // f + tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) + tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) + out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) + return out[0:m, 0:k] + + +def dequantize_to_dtype(tensor_fp4, + tensor_sf, + global_scale, + dtype, + device, + block_size=16): + """Dequantize the fp4 tensor back to high precision.""" + # Two fp4 values are packed into one uint8. + assert tensor_fp4.dtype == torch.uint8 + m, packed_k = tensor_fp4.shape + k = packed_k * 2 + tensor_f32 = break_fp4_bytes(tensor_fp4, torch.float32) + tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) + tensor_sf = tensor_sf.view(torch.float8_e4m3fn) + tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) + tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale + + # scale the tensor + out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) + return out.to(dtype) + + +def get_reciprocal(x): + if isinstance(x, torch.Tensor): + return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x) + elif isinstance(x, (float, int)): + return 0.0 if x == 0 else 1.0 / x + else: + raise TypeError("Input must be a float, int, or a torch.Tensor.") + + +def cast_to_fp4(x): + sign = torch.sign(x) + x = torch.abs(x) + x[(x >= 0.0) & (x <= 0.25)] = 0.0 + x[(x > 0.25) & (x < 0.75)] = 0.5 + x[(x >= 0.75) & (x <= 1.25)] = 1.0 + x[(x > 1.25) & (x < 1.75)] = 1.5 + x[(x >= 1.75) & (x <= 2.5)] = 2.0 + x[(x > 2.5) & (x < 3.5)] = 3.0 + x[(x >= 3.5) & (x <= 5.0)] = 4.0 + x[x > 5.0] = 6.0 + return x * sign + + +def ref_nvfp4_quant(x, global_scale, block_size): + assert global_scale.dtype == torch.float32 + assert x.ndim == 2 + m, n = x.shape + x = torch.reshape(x, (m, n // block_size, block_size)) + vec_max = torch.max(torch.abs(x), dim=-1, + keepdim=True)[0].to(torch.float32) + scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX)) + scale = torch.clamp(scale, max=448, min=-448) + scale = scale.to(torch.float8_e4m3fn).to(torch.float32) + output_scale = get_reciprocal(scale * get_reciprocal(global_scale)) + + scaled_x = x.to(torch.float32) * output_scale + clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n) + # both outputs are float32 + return cast_to_fp4(clipped_x), scale.squeeze(-1) + + +def run_nvfp4_emulations(x: torch.Tensor, input_global_scale: torch.Tensor, + weight: torch.Tensor, + weight_scale_swizzled: torch.Tensor, + weight_global_scale: torch.Tensor): + group_size = 16 + x_m, x_k = x.shape + output_dtype = x.dtype + + # quantize input to (FP4 and interleaved block scale) + x_fp4, x_blockscale = ref_nvfp4_quant(x, input_global_scale, group_size) + + # dequantize input + x_fp4 = x_fp4.reshape(x_m, x_k // group_size, group_size) + x_blockscale = x_blockscale.unsqueeze(-1) / input_global_scale + x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype) + del x_fp4, x_blockscale + + # dequantize weight + w_fp4 = weight.data.view(torch.uint8) + w_dq = dequantize_to_dtype(w_fp4, weight_scale_swizzled.data, + weight_global_scale, output_dtype, x.device, + group_size) + + # matmul + out = torch.matmul(x_dq, w_dq.t()) + del w_dq, x_dq + return out diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py new file mode 100644 index 0000000..d6b9677 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -0,0 +1,573 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""This file is used for /tests and /benchmarks""" +from collections.abc import Mapping +from types import MappingProxyType +from typing import Optional + +import numpy +import torch + +from vllm.model_executor.layers.quantization.qqq import ( + MARLIN_QQQ_SUPPORTED_NUM_BITS) +from vllm.scalar_type import ScalarType, scalar_types + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + +# Normalize the group_shape to the full extent for any dims that are -1 +def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int, + int]): + # -1 means full extent + return (group_shape[0] if group_shape[0] > 0 else x.shape[-2], + group_shape[1] if group_shape[1] > 0 else x.shape[-1]) + + +# Useful when treating N-dimensional group scaling as extended numpy-style +# broadcasting in numpy simply stretches dimensions with an extent of 1 to match +# the target shape by repeating the data along that dimension (broadcasting) +# , we extend these semantics to say if the extent of a dimension in the +# source shape is not 1 and does not match the target shape we repeat each +# element along that dimension src_shape[dim] // target_shape[dim] times +# example if we have: +# a = [[1, 2], and target_shape = (2, 4) +# [3, 4]] +# then we would expand a to: +# a = [[1, 1, 2, 2], +# [3, 3, 4, 4]] +# NOTE this function this function does not explicitly broadcast dimensions +# with an extent of 1, since this can be done implicitly by pytorch +def group_broadcast(t, shape): + for i, s in enumerate(shape): + if t.shape[i] != s and t.shape[i] != 1: + assert s % t.shape[i] == 0 + t = t.unsqueeze(i + 1)\ + .expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\ + .flatten(i, i + 1) + return t + + +# Quantize assuming once scale per group of elements with shape group_shape, +# example group shapes: +# * (-1, -1) for per-tensor quantization +# * (1, -1) for per-row quantization +# * (-1, 1) for per-column quantization +# * (128, 128) for 128x128 deepseek style block quantization +# * (1, 128) for deepseek style activation quantization +# (i.e. per-token-per-group) +def scaled_quantize( + x: torch.Tensor, + group_shape: tuple[int, int], + quant_dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + group_shape = _normalize_quant_group_shape(x, group_shape) + assert quant_dtype.is_floating_point, \ + "currently `scaled_quantize` only supports floating point dtypes " \ + "but could be extended to support other dtypes" + + finfo = torch.finfo(quant_dtype) + + # Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N) + assert x.ndim == 2 + assert x.shape[0] % group_shape[0] == 0 and x.shape[1] % group_shape[1] == 0 + blk_m, blk_n = x.shape[0] // group_shape[0], x.shape[1] // group_shape[1] + x_blkd = x.reshape(blk_m, group_shape[0], blk_n, group_shape[1]) + + # Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N) + x_blkd_permd = x_blkd.permute(0, 2, 1, 3) + # Flatten to (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) + x_blkd_permd = x_blkd_permd.flatten(start_dim=2) + + # Compute scales + min_val, max_val = x_blkd_permd.aminmax(dim=-1) + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax + + # Apply scale and convert form: + # (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N) + x_scl_sat = (x_blkd_permd * scale.unsqueeze(-1))\ + .clamp(min=finfo.min, max=finfo.max)\ + .reshape(blk_m, blk_n, group_shape[0], group_shape[1])\ + .permute(0, 2, 1, 3)\ + .reshape(x.shape) + + return x_scl_sat.to(quant_dtype).contiguous(), scale.float().reciprocal() + + +# inverses `scaled_quantize` +def scaled_dequantize( + x_q: torch.Tensor, + x_s: torch.Tensor, + group_shape: Optional[tuple[int, int]] = None, + out_dtype: torch.dtype = torch.float32, +) -> tuple[torch.Tensor, torch.Tensor]: + if group_shape is not None: + group_shape = _normalize_quant_group_shape(x_q, group_shape) + + if x_s.ndim == 0: # scalar + x_s = x_s.unsqueeze(-1).unsqueeze(-1) # convert to (1, 1) tensor + if x_s.ndim == 1: + if group_shape is None: + raise AssertionError( + "if x_s is 1D tensor, group_shape must be provided otherwise " + "its ambiguous which dimension to broadcast x_s to") + # unsqueeze the scales for the dimension where we want to broadcast + # across the full extent + if group_shape[0] == x_q.shape[-2]: + x_s = x_s.unsqueeze(-2) + elif group_shape[1] == x_q.shape[-1]: + x_s = x_s.unsqueeze(-1) + else: + raise AssertionError( + "if x_s is a vector we should be broadcasting it to the full " + "extent of one of the dimensions") + + if group_shape is not None: + assert x_s.shape[-1] == x_q.shape[-1] // group_shape[1] + assert x_s.shape[-2] == x_q.shape[-2] // group_shape[0] + x_s = group_broadcast(x_s.to(torch.float32), x_q.shape) + return (x_q.to(torch.float32) * x_s).to(out_dtype) + + +def pack_quantized_values_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + assert w_q_perm.shape[-1] % pack_factor == 0 + new_shape_perm[-1] //= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i + + return res.permute(inv_perm) + + +def unpack_quantized_values_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + new_shape_perm[-1] *= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask + + return res.permute(inv_perm) + + +def is_layer_skipped( + prefix: str, + ignored_layers: list[str], + fused_mapping: Mapping[str, list[str]] = MappingProxyType({}) +) -> bool: + # prefix: model.layers.0.self_attn.q_proj + # proj_name: q_proj + proj_name = prefix.split(".")[-1] + + # Fused layers like gate_up_proj or qkv_proj will not be fused + # in the safetensors checkpoint. So, we convert the name + # from the fused version to unfused + check to make sure that + # each shard of the fused layer has the same scheme. + if proj_name in fused_mapping: + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in fused_mapping[proj_name] + ] + + is_skipped = None + for shard_prefix in shard_prefixes: + is_shard_skipped = shard_prefix in ignored_layers + + if is_skipped is None: + is_skipped = is_shard_skipped + elif is_shard_skipped != is_skipped: + raise ValueError( + f"Detected some but not all shards of {prefix} " + "are quantized. All shards of fused layers " + "to have the same precision.") + else: + is_skipped = prefix in ignored_layers + + assert is_skipped is not None + return is_skipped + + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows(q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size, ), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def quantize_weights(w: torch.Tensor, + quant_type: ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False): + assert quant_type.is_integer(), \ + "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, \ + "to have group zero points, group_size must be provided "\ + "(-1 group_size is channelwise)" + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \ + .clamp(min_q_val, max_q_val).int() + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf))) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. + if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) + + +def gptq_quantize_weights(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None): + size_k, _ = w.shape + + assert w.is_floating_point(), "w must be float" + assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, \ + f"Unsupported gptq type = {quant_type}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k) + + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, + test_perm) + + return w_ref, w_q, w_s, g_idx, rand_perm + + +# QQQ employs different quant schemes for per-group and +# per-channel quantization. +def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \ + f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + if group_size < size_k: + # Reshape to [groupsize, -1] + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Compute scale for each group + s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_group *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s_group).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s_group + + # Restore original shapes + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + # Compute int8 quantization scale for each channel + s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] + s_channel /= 127.0 + t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) + w_ref = t_int8.half() * s_channel + s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) + + # Fuse scales + s_group = (s_group.reshape(-1, size_n).contiguous() / + s_channel).to(dtype=torch.half) + else: + max_q_val = 2**(num_bits - 1) - 1 + + # Compute scale for each channel + s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_channel /= max_q_val + + # Quantize + q_w = torch.round(w / s_channel).int() + q_w = torch.clamp(q_w, -max_q_val, max_q_val) + # Compute ref (dequantized) + w_ref = q_w.half() * s_channel + + s_group = torch.tensor([], dtype=torch.half) + # div 2 ** (8 - self.bits)) to offset right shift in unpacking + s_channel /= (2**(8 - num_bits)) + s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s_group.to(device=orig_device), + s_channel.to(device=orig_device), + ) + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to( + dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + + +def pack_rows( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, size_n // pack_factor + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + return pack_rows(q_w, num_bits, size_k, size_n) + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/vllm/model_executor/layers/quantization/utils/w4a8_utils.py b/vllm/model_executor/layers/quantization/utils/w4a8_utils.py new file mode 100644 index 0000000..bed3963 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/w4a8_utils.py @@ -0,0 +1,93 @@ + +import torch +import numpy as np + +try: + from lightop import awq_marlin_repack_w4a8 + use_lightop = True +except Exception: + use_lightop = False + +def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor: + """ + 将[N, K//2]大小的torch.int8 Tensor,转换为[N, K]大小的torch.int32 Tensor。 + 每个int8包含两个int4,分别提取到int32的低4位,其余位为0。 + + Args: + tensor_int8 (torch.Tensor): 输入张量,形状为[N, K//2],类型为torch.int8。 + + Returns: + torch.Tensor: 输出张量,形状为[N, K],类型为torch.int32。 + """ + if tensor_int8.dtype != torch.int8: + raise ValueError("Input tensor must be of type torch.int8") + + N, K_half = tensor_int8.shape + tensor_uint8 = tensor_int8.to(torch.uint8) + high4 = tensor_uint8 & 0x0F + low4 = (tensor_uint8 >> 4) & 0x0F + unpacked = torch.empty((N, K_half * 2), dtype=torch.int32, device=tensor_int8.device) + unpacked[:, 0::2] = low4.to(torch.int32) + unpacked[:, 1::2] = high4.to(torch.int32) + + return unpacked + +def get_weight_perms(interleave: bool=True): + perm = [] + for i in range(64): + + for col in range(4): + cur_col = (i % 16) * 4 + col + for row in range(8): + cur_row = (i // 16) * 8 + row + cur_idx = cur_row * 64 + cur_col + perm.append(cur_idx) + + perm = np.array(perm) + if interleave: + interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3]) + perm = perm.reshape((-1, 8))[:, interleave].ravel() + + perm = torch.from_numpy(perm) + + return perm + +def marlin_weights(q_w,weight_perm,k_tile=32,n_tile=64,pack_factor=8): + size_k, size_n = q_w.shape + q_w = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // k_tile, size_n * k_tile)) + q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape) + + orig_device = q_w.device + q_w = q_w.contiguous().to(torch.int32) + M, N = q_w.shape + assert N % pack_factor == 0, f"size_n ({N}) must be divisible by pack_factor ({pack_factor})" + q_packed = torch.zeros((M, N // pack_factor), dtype=torch.int32, device=orig_device) + for i in range(pack_factor): + q_packed += q_w[:, i::pack_factor] << (4 * i) + + return q_packed + +def w4a8_2_marlin_weight(w4a8_w): + full_w4a8_w = unpack_int8_to_int4(w4a8_w) + full_w4a8_w = full_w4a8_w.T + weight_perm = get_weight_perms() + marlin_q_w = marlin_weights(full_w4a8_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8) + return marlin_q_w + +def w4a8_weight_repack_impl(input): + if use_lightop: + size_batch = input.shape[0] + size_n = input.shape[1] + size_k = input.shape[2] * 2 + output = torch.zeros((size_batch, size_k // 32, size_n * 4), device=input.device, dtype=torch.int32) + awq_marlin_repack_w4a8(input, output, size_batch, size_k, size_n) + else: + w_marlin_list = [] + for e in range(input.shape[0]): + w_marlin_in = w4a8_2_marlin_weight(input[e]) + w_marlin_list.append(w_marlin_in) + output = torch.stack(w_marlin_list, dim=0) + + return output \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py new file mode 100644 index 0000000..0232491 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -0,0 +1,506 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Callable, Optional, Union + +import torch + +from vllm import _custom_ops as ops +from vllm import envs +from vllm.config import CompilationLevel, get_current_vllm_config +from vllm.platforms import current_platform +from vllm.utils import W8a8GetCacheJSON +from lmslim.layers.gemm.int8_utils import per_token_quant_int8 + +# Input scaling factors are no longer optional in _scaled_mm starting +# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale +TORCH_DEVICE_IDENTITY = None +W8A8_TRITONJSON=W8a8GetCacheJSON() + +# The condition to determine if it is on a platform that supports +# torch._scaled_mm rowwise feature. +# The condition is determined once as the operations +# are time consuming. +USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm() + and torch.__version__[0:3] >= "2.7" + and current_platform.has_device_capability(94)) + +def sparse_cutlass_supported() -> bool: + if not current_platform.is_cuda(): + return False + + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + + return ops.cutlass_sparse_scaled_mm_supported(capability) + + +def cutlass_fp8_supported() -> bool: + if not current_platform.is_cuda(): + return False + + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + + return ops.cutlass_scaled_mm_supports_fp8(capability) + + +def cutlass_block_fp8_supported() -> bool: + if not current_platform.is_cuda(): + return False + + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + + return ops.cutlass_scaled_mm_supports_block_fp8(capability) + + +def cutlass_group_gemm_supported() -> bool: + if not current_platform.is_cuda(): + return False + + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + + return ops.cutlass_group_gemm_supported(capability) + + +CUTLASS_FP8_SUPPORTED = cutlass_fp8_supported() +CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported() + + +def per_tensor_dequantize( + tensor: torch.Tensor, inv_scale: Union[float, + torch.Tensor]) -> torch.Tensor: + fake_qweight = tensor.to(torch.float16) + dq_weight = fake_qweight * inv_scale + return dq_weight + + +def all_close_1d(x: torch.Tensor) -> bool: + assert len(x.shape) == 1 + return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) + + +def convert_to_channelwise( + weight_scale: torch.Tensor, + logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]: + # Create channelwise buffer + weight_scale_channel = torch.empty((sum(logical_widths), 1), + dtype=torch.float32, + device=weight_scale.device) + + # Expand each scale to match the size of each logical matrix. + start = 0 + for idx, logical_width in enumerate(logical_widths): + end = start + logical_width + weight_scale_channel[start:end, :] = weight_scale[idx] + start = end + + return weight_scale_channel + + +def requantize_with_max_scale( + weight: torch.Tensor, weight_scale: torch.Tensor, + logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]: + # Max scale to be used for requanitzation. + max_w_scale = weight_scale.max() + + # QKV / MLP is fused in the on disk checkpoint if any of the + # weight scales are still set to the default since we initialize + # N weight scales for N shards but we only load 1 weight scale + # from disk in this case. Skip requantization in this case (since) + # we already are quantized with the single scale. + # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8 + unfused_module_in_checkpoint = (weight_scale[-1] + > torch.finfo(torch.float8_e4m3fn).min) + + # If unfused checkpoint, need requanize with the single scale. + if unfused_module_in_checkpoint: + start = 0 + for idx, logical_width in enumerate(logical_widths): + end = start + logical_width + weight_dq = per_tensor_dequantize(weight[start:end, :], + weight_scale[idx]) + weight[start:end, :], _ = ops.scaled_fp8_quant( + weight_dq, max_w_scale) + start = end + + return max_w_scale, weight + + +def maybe_create_device_identity(): + # Allocate dummy ones tensor for torch._scaled_mm + global TORCH_DEVICE_IDENTITY + if TORCH_DEVICE_IDENTITY is None: + TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) + + +def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, + out_dtype: torch.dtype, scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + output_shape: list, **kwargs) -> torch.Tensor: + + # Fused GEMM_DQ + output = ops.cutlass_scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + return output.view(*output_shape) + + +def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: list) -> torch.Tensor: + from vllm.platforms.rocm import on_mi3xx + if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx( + ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: + output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b, + current_platform.get_cu_count()) + else: + output = torch._scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + + return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + + +def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: list) -> torch.Tensor: + output = torch._scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + + return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + + +def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: list) -> torch.Tensor: + # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM + # when using it. + # For now it has only been validated on ROCm platform. + # fp8 rowwise scaling in torch._scaled_mm is introduced in + # https://github.com/pytorch/pytorch/pull/144432 using + # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above. + # + # For CUDA platform please validate if the torch._scaled_mm supports + # rowwise scaled GEMM before using it + + # Fused GEMM_DQ Rowwise GEMM + output = torch._scaled_mm(qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b.t(), + bias=bias) + + output = torch.narrow(output, 0, 0, input_2d.shape[0]) + output = output.view(*output_shape) + return output + + +def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: list, + **kwargs) -> torch.Tensor: + # Use unfused DQ due to limitations with scaled_mm + + # Symmetric quantized GEMM by definition computes the following: + # C = (s_x * X) (s_w * W) + bias + # This is equivalent to dequantizing the weights and activations + # before applying a GEMM. + # + # In order to compute quantized operands, a quantized kernel + # will rewrite the above like so: + # C = s_w * s_x * (X * W) + bias + # + # For the scaled_mm fallback case, we break this down, since it + # does not support s_w being a vector. + + # GEMM + # This computes C = (X * W). + # Output in fp32 to allow subsequent ops to happen in-place + output = torch._scaled_mm(qinput, + weight, + scale_a=TORCH_DEVICE_IDENTITY, + scale_b=TORCH_DEVICE_IDENTITY, + out_dtype=torch.float32) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + # Unpad (undo num_token_padding) + output = torch.narrow(output, 0, 0, input_2d.shape[0]) + x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0]) + + # DQ + # C = sw * sx * (X * W) + bias + output = output * x_scale * scale_b.t() + if bias is not None: + output = output + bias + return output.to(out_dtype).view(*output_shape) + + +def dispatch_w8a8_scaled_mm( + cutlass_fp8_supported: bool, per_tensor_weights: bool, + per_tensor_activations: bool, use_per_token_if_dynamic: Optional[bool] +) -> Callable[..., torch.Tensor]: + + if cutlass_fp8_supported: + return cutlass_w8a8_scaled_mm + if per_tensor_weights and per_tensor_activations: + if current_platform.is_rocm(): + return rocm_per_tensor_w8a8_scaled_mm + return torch_per_tensor_w8a8_scaled_mm + # torch.scaled_mm supports per tensor weights + activations only + # so fallback to naive if per channel or per token + if (use_per_token_if_dynamic and not per_tensor_weights + and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM): + return torch_per_token_w8a8_scaled_mm + return torch_channelwise_w8a8_scaled_mm + + +# TODO(luka): follow similar pattern for marlin and block-fp8-linear +# https://github.com/vllm-project/vllm/issues/14397 +class Fp8LinearOp: + """ + This class executes a FP8 linear layer using cutlass if supported and + torch.scaled_mm otherwise. + It needs to be a class instead of a method so that config can be read + in the __init__ method, as reading config is not allowed inside forward. + """ + + def __init__(self, + cutlass_fp8_supported: bool = cutlass_fp8_supported(), + use_per_token_if_dynamic: bool = False, + pad_output: Optional[bool] = None): + self.cutlass_fp8_supported = cutlass_fp8_supported + self.use_per_token_if_dynamic = use_per_token_if_dynamic + + # Note: we pad the input because torch._scaled_mm is more performant + # for matrices with batch dimension > 16. + # This could change in the future. + # We also don't pad when using torch.compile, + # as it breaks with dynamic shapes. + if pad_output is None: + config = get_current_vllm_config().compilation_config + pad_output = config.level < CompilationLevel.PIECEWISE + self.output_padding = 17 if ( + pad_output and not current_platform.is_rocm()) else None + + def apply( + self, + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + out_dtype: Optional[torch.dtype] = None, + input_scale: Optional[torch.Tensor] = None, + input_scale_ub: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + # TODO(luka) remove this parameter in favor of __init__ + use_per_token_if_dynamic: Optional[bool] = None + ) -> torch.Tensor: + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.input_scale is None and x_scale computed from x. + # If static, layer.input_scale is scalar and x_scale is input_scale. + + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[1]] + + # TODO(luka) this is here because currently MLA only decides this + # during the forward method instead of in __init__. + if use_per_token_if_dynamic is None: + use_per_token_if_dynamic = self.use_per_token_if_dynamic + + if out_dtype is None: + out_dtype = input.dtype + + # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A + if self.cutlass_fp8_supported: + assert input.dtype != current_platform.fp8_dtype( + ), "FP8 input to cutlass is not currently implemented" + qinput, x_scale = ops.scaled_fp8_quant( + input_2d, + input_scale, + scale_ub=input_scale_ub, + use_per_token_if_dynamic=use_per_token_if_dynamic) + else: + if input.dtype != current_platform.fp8_dtype(): + # Maybe apply padding to output, see comment in __init__ + qinput, x_scale = ops.scaled_fp8_quant( + input_2d, + input_scale, + num_token_padding=self.output_padding, + use_per_token_if_dynamic=use_per_token_if_dynamic) + else: + qinput, x_scale = input_2d, input_scale + + per_tensor_weights = (weight_scale.numel() == 1) + per_tensor_activations = (x_scale.numel() == 1) + + w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm( + self.cutlass_fp8_supported, per_tensor_weights, + per_tensor_activations, use_per_token_if_dynamic) + + return w8a8_scaled_mm_func(qinput=qinput, + weight=weight, + out_dtype=out_dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias, + input_2d=input_2d, + output_shape=output_shape) + + +def apply_int8_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + input_zero_point: Optional[torch.Tensor] = None, + azp_adj: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + w8a8_strategy:Optional[int]=0, +): + # ops.scaled_int8_quant supports both dynamic and static quant. + # * dynamic, layer.input_scale is None and x_scale computed from x. + # * static, layer.input_scale is scalar and x_scale is input_scale. + symmetric = azp_adj is None + if input_scale is None and input_zero_point is None and symmetric is True: + x_q, x_scale=per_token_quant_int8(input) + x_zp =None + + else: + x_q, x_scale, x_zp = ops.scaled_int8_quant(input, + input_scale, + input_zero_point, + symmetric=symmetric) + + if x_zp is not None: + # Currently, static is always per-tensor and dynamic is per-token + static = input_zero_point is not None + azp = None if static else x_zp + return ops.cutlass_scaled_mm_azp(x_q, + weight, + scale_a=x_scale, + scale_b=weight_scale, + out_dtype=input.dtype, + azp_adj=azp_adj, + azp=azp, + bias=bias) + if w8a8_strategy==1: + m=x_q.shape[0] + k=x_q.shape[1] + n=weight.shape[1] + #print("m:{},n:{},k:{}".format(m,n,k)) + if len(W8A8_TRITONJSON.triton_json_dict)==0: + best_config=None + + elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict: + if m<=16: + m_=m + #best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m}_{n}_{k}"] + elif m<=64: + m_= (m + 3) & -4 #取值到最近的4的倍数 + elif m<=160: + m_=(m + 7) & -8 + + elif m<200: #256 + m_=160 + elif m<480: #512 + m_=256 + elif m<960: #1024 + m_=512 + elif m<2048: + m_=1024 + elif m<4096: + m_=2048 + elif m<6000: + m_=4096 + else: + m_=8192 + + best_config=W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"] + + else: + best_config=None + # if best_config==None: + # print("m:{},n:{},k:{}".format(m,n,k)) + # print("config not found!") + + + return ops.triton_scaled_mm(x_q, + weight, + scale_a=x_scale, + scale_b=weight_scale, + out_dtype=input.dtype, + bias=bias,best_config=best_config) + elif w8a8_strategy==2: + return ops.cutlass_scaled_mm(x_q, + weight, + scale_a=x_scale, + scale_b=weight_scale, + out_dtype=input.dtype, + bias=bias) + else: + return ops.rocblas_scaled_mm(x_q, + weight, + scale_a=x_scale, + scale_b=weight_scale, + out_dtype=input.dtype, + bias=bias) + + +def normalize_e4m3fn_to_e4m3fnuz( + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + assert weight.dtype == torch.float8_e4m3fn + # The bits pattern 10000000(-128) represents zero in e4m3fn + # but NaN in e4m3fnuz. So here we set it to 0. + # https://onnx.ai/onnx/technical/float8.html + weight_as_int8 = weight.view(torch.int8) + ROCM_FP8_NAN_AS_INT = -128 + weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0 + weight = weight_as_int8.view(torch.float8_e4m3fnuz) + + # For the same bits representation, e4m3fnuz value is half of + # the e4m3fn value, so we should double the scaling factor to + # get the same dequantized value. + # https://onnx.ai/onnx/technical/float8.html + weight_scale = weight_scale * 2.0 + if input_scale is not None: + input_scale = input_scale * 2.0 + return weight, weight_scale, input_scale diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py new file mode 100644 index 0000000..db68f18 --- /dev/null +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -0,0 +1,406 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from functools import cached_property +from importlib.util import find_spec +from typing import Optional + +import torch +import torch.jit + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeStochasticBaseSampler) +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +if find_spec("flashinfer"): + """ + Consider utilizing the FlashInfer rejection sampling kernel initially, + as it employs a dedicated kernel rather than relying on + Torch tensor operations. This design choice helps to fuse operations, + reduce memory I/O, and consequently enhances performance. + """ + from flashinfer.sampling import chain_speculative_sampling +else: + chain_speculative_sampling = None + + +class RejectionSampler(SpecDecodeStochasticBaseSampler): + """Apply modified rejection sampling as described in "Accelerating Large + Language Model Decoding with Speculative Sampling" + https://arxiv.org/pdf/2302.01318.pdf. + """ + + def __init__(self, + strict_mode: bool = False, + use_flashinfer: Optional[bool] = None): + """Create a rejection sampler. + + Args: + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. + use_flashinfer: We will use this parameter to determine whether + to use the FlashInfer rejection sampling kernel or not. If it's + None, we will use the default value from the environment variable. + This parameter is only used for testing purposes. + """ + super().__init__(strict_mode=strict_mode) + if use_flashinfer is None: + self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and ( + chain_speculative_sampling is not None) + else: + self.use_flashinfer = use_flashinfer + + if self.use_flashinfer: + logger.info("Use flashinfer for rejection sampling.") + else: + logger.info("Use pytorch for rejection sampling.") + + def forward( + self, + target_with_bonus_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + seeded_seqs: Optional[dict[int, torch.Generator]] = None, + ) -> torch.Tensor: + """Sample token ids using rejection sampling. This accepts or rejects + tokens proposed by the draft model using the probability of each token + according to the draft and target models. + + In the worst case where all draft tokens are rejected, it is guaranteed + one correct token will be emitted. + + In the case where all draft tokens are accepted, a bonus token will be + accepted as its cheap to have the target model score this speculative + sequence. + + Args: + target_with_bonus_probs: The probability distribution + over token ids given context according to the target model. + shape = [batch_size, num_speculative_tokens + 1, vocab_size] + + bonus_token_ids: The "bonus" token ids that are accepted iff all + speculative tokens in a sequence are accepted. + shape = [batch_size, num_bonus_tokens] + + draft_probs: The probability distribution over token ids given + context according to the draft model. + shape = [batch_size, num_speculative_tokens, vocab_size] + + draft_token_ids: The token ids that were sampled from the draft + probabilities. + shape = [batch_size, num_speculative_tokens] + + seeded_seqs: Dict of batch row index to torch generator, for + sequences using seeded generation. + + Returns: + output_token_ids: The token ids sampled via rejection sampling, + or -1 if unable to sample a token because the previous token + was rejected. + shape = [batch_size, num_speculative_tokens + num_bonus_tokens] + """ + # Only perform shape/dtype/device checking in strict mode, as it adds + # overhead. + if self._strict_mode: + self._raise_if_incorrect_input(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + + batch_size, k, _ = draft_probs.shape + + # batch_size = 0 when all requests in the batch are + # non_spec requests. In this case, output_token_ids is + # just an empty tensor. + if batch_size == 0: + return torch.empty(0, k + 1, device=draft_probs.device, dtype=int) + + # If use Flashinfer chain_speculative_sampling kernel + # for rejection sampling + if self.use_flashinfer and chain_speculative_sampling is not None: + batch_size, k, _ = draft_probs.shape + + (output_token_ids, accepted_token_num, + emitted_token_num) = chain_speculative_sampling( + draft_probs, + draft_token_ids, + target_with_bonus_probs, + ) + + # num_emitted_tokens returned by flashinfer + # does not include the bonus token + # Flashinfer stops at the first token that violates + # the condition p >= q and does not include recovery/bonus token. + # Therefore, we need to add batch_size here. + self.num_accepted_tokens += accepted_token_num.sum() + self.num_emitted_tokens += emitted_token_num.sum() + batch_size + self.num_draft_tokens += batch_size * k + else: + accepted, recovered_token_ids = ( + self._batch_modified_rejection_sampling( + target_with_bonus_probs[:, :-1], + draft_probs, + draft_token_ids, + seeded_seqs, + )) + + output_token_ids = self._create_output( + accepted, + recovered_token_ids, + draft_token_ids, + bonus_token_ids, + ) + + return output_token_ids + + def _batch_modified_rejection_sampling( + self, + target_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_token_ids: torch.Tensor, # [batch_size, k] + seeded_seqs: Optional[dict[int, torch.Generator]], + ) -> tuple[torch.Tensor, torch.Tensor]: + """Perform modified rejection sampling on each sequence. + + Returns: + A tuple of two tensors: + 0: A bool tensor of which tokens in each sequence is accepted. + shape = [batch_size, k] + 1: Token ids sampled from a recovered distribution, to be used + when a token is rejected. + shape = [batch_size, k] + """ + + batch_size, k, vocab_size = draft_probs.shape + + # shape [batch_size, k] + accepted = self._get_accepted(target_probs, draft_probs, + draft_token_ids, seeded_seqs) + + recovered_probs = self._get_recovered_probs( + target_probs, draft_probs).reshape(batch_size * k, vocab_size) + + # NOTE: the recovered_probs are overwritten by this method. + recovered_token_ids = _multinomial( + recovered_probs, + num_samples=1, + k=k, + seeded_seqs=seeded_seqs or {}, + ).reshape(batch_size, k) + + return accepted, recovered_token_ids + + def _create_uniform_samples(self, + seeded_seqs: Optional[dict[int, + torch.Generator]], + batch_size: int, k: int, + device: torch.device) -> torch.Tensor: + """ + Generates a batch of uniform random samples, with optional seeding + for specific sequences. + + This method creates a tensor of shape `(batch_size, k + 1)` filled + with uniform random values in the range [0, 1). If `seeded_seqs` + is provided, the sequences corresponding to specific indices + will be generated using the provided `torch.Generator` for + reproducibility. The other sequences will be generated without + a seed. + + Args: + seeded_seqs : Optional[dict[int, torch.Generator]] + A dictionary mapping indices in the batch to + `torch.Generator` objects. If `None`, all samples are + generated without a seed. + batch_size : int + The number of sequences to generate. + k : int + The number of random samples per sequence. + device : torch.device + The device on which to allocate the tensor. + + Returns: + uniform_rand : torch.Tensor + A tensor of shape `(batch_size, k + 1)` containing uniform + random values in the range [0, 1). + """ + if not seeded_seqs: + return torch.rand(batch_size, k + 1, device=device) + + uniform_rand = torch.empty(batch_size, k + 1, device=device) + + non_seeded_indices = [] + for idx in range(batch_size): + generator = seeded_seqs.get(idx) + if generator is None: + non_seeded_indices.append(idx) + else: + uniform_rand[idx, :] = torch.rand(1, + k + 1, + dtype=self.probs_dtype, + device=device, + generator=generator) + if non_seeded_indices: + uniform_rand[non_seeded_indices, :] = torch.rand( + len(non_seeded_indices), + k + 1, + dtype=self.probs_dtype, + device=device) + return uniform_rand + + def _get_accepted( + self, + target_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_token_ids: torch.Tensor, # [batch_size, k] + seeded_seqs: Optional[dict[int, torch.Generator]], + ) -> torch.Tensor: + r"""Create bool matrix over the proposed draft tokens. If + True, then a token can be accepted, else it should be + rejected. + + Given $q(\hat{x}_{n+1}|x_1, \dots, x_n)$, the probability of + $\hat{x}_{n+1}$ given context $x_1, \dots, x_n$ according + to the target model, and $p(\hat{x}_{n+1}|x_1, \dots, x_n)$, the + same conditional probability according to the draft model, the token + is accepted with probability: + + $$ + \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)} + {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right) + $$ + + This implementation does not apply causality. When using the output, + if a token is rejected, subsequent tokens should not be used. + + Returns a bool tensor of shape [batch_size, k] specifying which tokens + are accepted. + """ + batch_size, k, _ = draft_probs.shape + batch_indices = torch.arange(batch_size, + device=target_probs.device)[:, None] + probs_indices = torch.arange(k, device=target_probs.device) + + # shape [batch_size, k] + selected_draft_probs = draft_probs[batch_indices, probs_indices, + draft_token_ids] + + # shape [batch_size, k] + selected_target_probs = target_probs[batch_indices, probs_indices, + draft_token_ids] + + uniform_rand = self._create_uniform_samples(seeded_seqs, batch_size, + k - 1, target_probs.device) + + capped_ratio = torch.minimum( + selected_target_probs / selected_draft_probs, + torch.full((1, ), 1, device=target_probs.device)) + accepted = uniform_rand < capped_ratio + + return accepted + + def _get_recovered_probs( + self, + target_probs: torch.Tensor, # [k, vocab_size] + draft_probs: torch.Tensor, # [k, vocab_size] + ) -> torch.Tensor: + r"""Create a probability distribution for each proposed token which can + be sampled if the proposed token is rejected. + + When this routine is applied sequentially, the true distribution of the + target model is recovered (within hardware numerics). + + The probability distribution used in this rejection case is constructed + as follows. Given $q(x|x_1, \dots, x_n)$, the probability of + $x$ given context $x_1, \dots, x_n$ according to the target + model and $p(x|x_1, \dots, x_n)$, the same conditional probability + according to the draft model: + + $$ + x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+ + $$ + + where $(f(x))_+$ is defined as: + + $$ + (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))} + $$ + + See https://github.com/vllm-project/vllm/pull/2336 for a visualization + of the draft, target, and recovered probability distributions. + + Returns a tensor of shape [batch_size, k, vocab_size]. + + Note: + This batches operations on GPU and thus constructs the recovered + distribution for all tokens, even if they are accepted. This causes + division-by-zero errors, so we use self._smallest_positive_value to + avoid that. This introduces some drift to the distribution. + """ + _, k, _ = draft_probs.shape + + # shape [batch_size, k, vocab_size] + difference = target_probs - draft_probs + + # TODO(cade): Can we use logprobs instead of probs, and avoid the + # division-by-zero errors without introducing distribution drift? + + # shape [batch_size, k, vocab_size] + f = torch.clamp(difference, min=self._smallest_positive_value) + + # shape [batch_size, k, vocab_size] + recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1) + + return recovered_probs + + @cached_property + def _smallest_positive_value(self) -> float: + """Return the smallest positive value representable by the probs dtype. + This value is used when constructing a distribution from which to sample + recovered tokens in the first rejection case. + + See _get_recovered_probs for more details + + Note that this isn't actually the smallest positive value representable + by float32, but the smallest positive normal value. + See https://en.wikipedia.org/wiki/Subnormal_number for more information. + """ + return torch.finfo(self.probs_dtype).tiny + + +# torch.multinomial forces a GPU<->CPU sync. +# Therefore, we use an optimized implementation instead that skips the sync. +# Note that we always sample with replacement. +# probs will be modified in place, but this is fine, as we pass +# in a copy already. +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +def _multinomial( + probs: torch.Tensor, + num_samples: int, + k: int, + seeded_seqs: dict[int, torch.Generator], +) -> torch.Tensor: + + if num_samples > 1: + # This is equivalent to torch.repeat_interleaved (which also + # forces a GPU<->CPU sync). + probs = probs[:, None, :].expand(probs.shape[0], num_samples, + probs.shape[1]).contiguous().view( + -1, probs.shape[1]) + q = torch.empty_like(probs) + if not seeded_seqs: + q.exponential_(1.0) + else: + start = 0 + for idx in range(len(q) // k): + end = start + k + generator = seeded_seqs.get(idx) + # Note: generator might be None for non seeded + q[start:end].exponential_(1.0, generator=generator) + start = end + + return probs.div_(q).argmax(dim=1).view(-1, num_samples) diff --git a/vllm/model_executor/layers/resampler.py b/vllm/model_executor/layers/resampler.py new file mode 100644 index 0000000..3f2d571 --- /dev/null +++ b/vllm/model_executor/layers/resampler.py @@ -0,0 +1,270 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +# +# Copyright 2023 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Shared resampler perceiver network used in multimodal models and +related helpers for sincos positional embeddings. + +Example models: Qwen (Qwen-VL), MiniCPM-V 2.0 +""" +import math +from functools import partial +from typing import Callable, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.quantization import QuantizationConfig + +DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) + + +def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor, + int]) -> torch.Tensor: + # abs_pos: L, C + # tgt_size: (H, W) + # return: M, C + src_size = int(math.sqrt(abs_pos.size(0))) + dtype = abs_pos.dtype + if isinstance(tgt_size, int): + tgt_size = (tgt_size, tgt_size) + if (src_size == tgt_size[0] and src_size == tgt_size[1]): + return abs_pos + return (F.interpolate( + abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), + size=(tgt_size[0], tgt_size[1]), + mode="bicubic", + align_corners=False, + ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)) + + +# sin/cos positional embedding helpers are adapted from: +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +def get_1d_sincos_pos_embed_from_grid( + embed_dim: int, pos: np.ndarray, + version: tuple[int, int] = (2, 0)) -> torch.Tensor: + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) / (H, W) + out: (M, D) / (H, W, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + if version == (2, 0): + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + else: + out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product + emb_sin = np.sin(out) # (H, W, D/2) + emb_cos = np.cos(out) # (H, W, D/2) + emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D) + return emb + + +def get_2d_sincos_pos_embed_from_grid( + embed_dim: int, grid: np.ndarray, + version: tuple[int, int] = (2, 0)) -> torch.Tensor: + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2) + + if version == (2, 0): + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + else: + emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D) + return emb + + +def get_2d_sincos_pos_embed( + embed_dim: int, + grid_size: Union[int, tuple[int, int]], + cls_token: bool = False, + version: tuple[int, int] = (2, 0), +) -> torch.Tensor: + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_h_size, grid_w_size = grid_size, grid_size + else: + grid_h_size, grid_w_size = grid_size[0], grid_size[1] + + grid_h = np.arange(grid_h_size, dtype=np.float32) + grid_w = np.arange(grid_w_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + assert isinstance(grid, np.ndarray) and \ + grid.shape == (2, grid_h_size, grid_w_size) + + if version == (2, 0): + grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], + axis=0) + else: + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) + return pos_embed + + +class BaseResampler(nn.Module): + """ + A 2D perceiver-resampler network with one cross attention layers by + (grid_size**2) learnable queries and 2d sincos pos_emb. + Outputs: + A tensor with the shape of (grid_size**2, embed_dim) + """ + + def __init__(self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + do_post_projection: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + + self.num_queries = num_queries + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.query = nn.Parameter(torch.empty(self.num_queries, embed_dim)) + + if kv_dim is not None and kv_dim != embed_dim: + self.kv_proj = ReplicatedLinear(kv_dim, + embed_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_proj") + else: + # Maintain the same return value with ReplicatedLinear.forward + self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa + nn.Identity()(*args, **kwargs), + None, + ) + self.attn = nn.MultiheadAttention(embed_dim, num_heads) + self.ln_q = norm_layer(embed_dim) + self.ln_kv = norm_layer(embed_dim) + self.do_post_projection = do_post_projection + self.ln_post = norm_layer(embed_dim) if do_post_projection else None + self.proj = nn.Parameter( + (embed_dim**-0.5) * + torch.empty(embed_dim, embed_dim)) if do_post_projection else None + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class Resampler2(BaseResampler): + """Resampler-perceiver network to be used for a variety of model types, + e.g., Qwen-vl / Minicpmv 2.0. The main difference is the addition of the + do_post_projection arg, which indicates whether or not there should be + a post layer normalization and projector after the attention. This is + present in minicpmv2.0, but not qwen-vl. + """ + + def __init__(self, + grid_size: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + adaptive: bool = False, + do_post_projection: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__(grid_size**2, + embed_dim, + num_heads, + kv_dim, + norm_layer, + do_post_projection=do_post_projection, + quant_config=quant_config, + prefix=prefix) + + self.adaptive = adaptive + pos_embed_arr = get_2d_sincos_pos_embed(embed_dim, + grid_size, + version=(2, 0)) + + self.pos_embed = nn.Parameter( + torch.from_numpy(pos_embed_arr).requires_grad_(False)) + + def forward( + self, + x: torch.Tensor, + tgt_sizes: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if tgt_sizes is None: + tgt_sizes = int(math.sqrt(x.size(1))) + if self.adaptive: + pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, + tgt_sizes, + version=(2, 0)) + pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device, + dtype=x.dtype) + else: + pos_embed = get_abs_pos(self.pos_embed, + tgt_sizes).to(device=x.device, + dtype=x.dtype) + + x, _ = self.kv_proj(x) + x = self.ln_kv(x).permute(1, 0, 2) + + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn( + self._repeat(q, N) + self.pos_embed.unsqueeze(1), + x + pos_embed.unsqueeze(1), + x, + attn_mask=attn_mask, + )[0] + x = out.permute(1, 0, 2) + if self.do_post_projection: + x = self.ln_post(x) + x = x @ self.proj + return x diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py new file mode 100644 index 0000000..a7c5ec5 --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -0,0 +1,2089 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Rotary Positional Embeddings.""" +import itertools +import math +from typing import Any, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import triton +import triton.language as tl + +from transformers import PretrainedConfig + +from vllm.model_executor.custom_op import CustomOp +from vllm.platforms import current_platform + +if current_platform.is_cuda(): + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb +if current_platform.is_rocm(): + from flash_attn.layers.rotary import apply_rotary_emb + + +def _rotate_neox(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def _apply_rotary_emb_torch( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, + is_neox_style: bool) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. + """ + if current_platform.is_cuda(): + return apply_rotary_emb(x.unsqueeze(0), cos, sin, + not is_neox_style).squeeze(0) + else: + return _apply_rotary_emb_torch(x, cos, sin, is_neox_style) + + +@CustomOp.register("rotary_embedding") +class RotaryEmbedding(CustomOp): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: float) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / (base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = _apply_rotary_emb_torch(query_rot, cos, sin, + self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + # key may be None in some cases, e.g. cross-layer KV sharing + if key is not None: + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = _apply_rotary_emb_torch(key_rot, cos, sin, + self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + from vllm import _custom_ops as ops + + # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`) + # is expensive, so avoid calling it if possible + if self.cos_sin_cache.device != query.device or \ + self.cos_sin_cache.dtype != query.dtype: + self.cos_sin_cache = self.cos_sin_cache.to(query.device, + dtype=query.dtype) + + # ops.rotary_embedding()/batched_rotary_embedding() + # are in-place operations that update the query and key tensors. + if offsets is not None: + ops.batched_rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, + self.is_neox_style, self.rotary_dim, + offsets) + else: + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) + return query, key + + def forward_xpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + from vllm._ipex_ops import ipex_ops as ops + + self.cos_sin_cache = self.cos_sin_cache.to(positions.device, + dtype=query.dtype) + # ops.rotary_embedding()/batched_rotary_embedding() + # are in-place operations that update the query and key tensors. + if key is None: + # XPU kernel doesn't support key=None so fall back to native impl + # TODO(sarckk): add support for optional key in + # ipex.llm.functional.rotary_embedding_batched + return self.forward_native(positions, query, key, offsets) + else: + if offsets is not None: + ops.batched_rotary_embedding(positions, query, key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + self.rotary_dim, offsets) + else: + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) + return query, key + + def forward_hpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, apply_rotary_pos_emb) + if offsets is not None: + offsets = offsets.view(positions.shape[0], -1) + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions).view( + num_tokens, 1, -1) + cos, sin = cos_sin.chunk(2, dim=-1) + # HPU RoPE kernel requires hidden dimension for cos and sin to be equal + # to query hidden dimension, so the original tensors need to be + # expanded + # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE + # and expansion of cos/sin tensors via concatenation + # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE + # and expansion of cos/sin tensors via repeat_interleave + rope_mode: RotaryPosEmbeddingMode + if self.is_neox_style: + rope_mode = RotaryPosEmbeddingMode.BLOCKWISE + cos = torch.cat((cos, cos), dim=-1) + sin = torch.cat((sin, sin), dim=-1) + else: + rope_mode = RotaryPosEmbeddingMode.PAIRWISE + sin = torch.repeat_interleave(sin, + 2, + dim=-1, + output_size=cos_sin.shape[-1]) + cos = torch.repeat_interleave(cos, + 2, + dim=-1, + output_size=cos_sin.shape[-1]) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, + rope_mode) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + if key is not None: + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, + rope_mode) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def forward_neuron( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + + def _apply_rotary_emb_neuron( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, + ) -> torch.Tensor: + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + # x1 = x[..., ::2] + + # x2 = x[..., 1::2] + d = x.shape[-1] // 2 + x_reshaped = x.view(-1, x.shape[-1]) + x1 = x_reshaped[:, ::2].view(*x.shape[:-1], d) + x2 = x_reshaped[:, 1::2].view(*x.shape[:-1], d) + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + if offsets is not None: + positions = positions + offsets + + self.cos_sin_cache = self.cos_sin_cache.to(query.device, + dtype=query.dtype) + + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + if key is not None: + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + + if self.rotary_dim == self.head_size: + query = _apply_rotary_emb(query, cos, sin, self.is_neox_style) + query = query.reshape(query_shape) + if key is not None: + key = _apply_rotary_emb(key, cos, sin, self.is_neox_style) + key = key.reshape(key_shape) + else: + head_size = query.shape[-1] + query_reshaped = query.view(-1, head_size) + query_pass = query_reshaped[:, self.rotary_dim:].view( + *query.shape[:-1], head_size - self.rotary_dim) + query_rot = query_reshaped[:, :self.rotary_dim].view( + *query.shape[:-1], self.rotary_dim) + query_rot = _apply_rotary_emb_neuron(query_rot, cos, sin, + self.is_neox_style) + query = torch.cat((query_rot, query_pass), + dim=-1).reshape(query_shape) + + if key is not None: + key_reshaped = key.view(-1, head_size) + key_pass = key_reshaped[:, self.rotary_dim:].view( + *key.shape[:-1], head_size - self.rotary_dim) + key_rot = key_reshaped[:, :self.rotary_dim].view( + *key.shape[:-1], self.rotary_dim) + key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin, + self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + + +class LinearScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with linear scaling. + + It supports multiple scaling factors. Since multiple LoRA adapters may have + different scaling factors, we need multiple cos/sin caches. In this way, + instead of running rotary embedding kernel per lora, we can run multiple + lora in a batched way. + + In addition to that, we also keep the cos/sin cache for the scaling factor + of 1 (default) at all times. + + Exemplary for two scaling factors x=1, y and z with embeddings + [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and + [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and + [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]], + + we construct the cos/sin cache as follows: + [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p], + ... + [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]] + + We then use offsets to index into the cos/sin cache for + the respective scaling factors. + + The offset to cache can be accessed via `scaling_factor_to_offset` API. + + Credits to the Reddit user /u/kaiokendev + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + scaling_factors: Union[list[float], float], + dtype: torch.dtype, + ) -> None: + if isinstance(scaling_factors, float): + scaling_factors = [scaling_factors] + self.scaling_factors: list[float] = scaling_factors # noqa + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + # Lazy initialized. + self._scaling_factor_to_offset: dict[float, int] + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.base) + cache_list: list[torch.Tensor] = [] + # offsets to the next cache in a tensor. + # Each offset corresponds to the same index in scaling_factors. + offsets: list[int] = [] + for scaling_factor in self.scaling_factors: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. + max_len = self.max_position_embeddings * scaling_factor + t = torch.arange(max_len, dtype=torch.float) + t = t / scaling_factor + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + if not cache_list: + offset = 0 + else: + last_offset = offsets[-1] + next_max_len = cache_list[-1].shape[0] + offset = last_offset + next_max_len + offsets.append(offset) + cache_list.append(cache) + self._scaling_factor_to_offset = { + float(scaling_factor): offsets[i] + for i, scaling_factor in enumerate(self.scaling_factors) + } + assert len(self.scaling_factors) == len(offsets) + return torch.cat(cache_list, dim=0) + + @property + def scaling_factor_to_offset(self) -> dict[float, int]: + return self._scaling_factor_to_offset + + +class NTKScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with fixed and mixed NTK scaling. + https://kexue.fm/archives/9706 """ + + def __init__(self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + mixed_b: Optional[float] = None) -> None: + self.scaling_factor = scaling_factor + self.mixed_b = mixed_b + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, base: float) -> torch.Tensor: + base = self.base * (self.scaling_factor if self.mixed_b is None else 1) + inv_freq = super()._compute_inv_freq(base) + + if self.mixed_b is None: + inv_freq = inv_freq / self.scaling_factor**(2 / self.rotary_dim) + else: + a = torch.tensor(self.scaling_factor).log() / (self.rotary_dim / + 2)**self.mixed_b + lambda_1_m = (a * torch.arange( + 1, self.rotary_dim // 2 + 1).float()**self.mixed_b).exp() + inv_freq = inv_freq / lambda_1_m + + return inv_freq + + +class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with Dynamic NTK scaling. + + Credits to the Reddit users /u/bloc97 and /u/emozilla + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + ) -> None: + self.scaling_factor = scaling_factor + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_cos_sin_cache(self) -> torch.Tensor: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. + max_len = self.max_position_embeddings * self.scaling_factor + base = self.base * ( + (self.scaling_factor * max_len / self.max_position_embeddings) - + (self.scaling_factor - 1))**(self.rotary_dim / + (self.rotary_dim - 2)) + inv_freq = self._compute_inv_freq(base) + t = torch.arange(max_len, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + +class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with Dynamic NTK alpha. + + Based on the original RotaryEmbedding implementation. + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + scaling_alpha: float, + dtype: torch.dtype, + ) -> None: + self.scaling_alpha = scaling_alpha + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_cos_sin_cache(self) -> torch.Tensor: + # For Hunyuan DynamicNTKAlphaRotaryEmbedding + max_len = self.max_position_embeddings + base = self.base * self.scaling_alpha**(self.rotary_dim / + (self.rotary_dim - 2)) + inv_freq = self._compute_inv_freq(base) + t = torch.arange(max_len, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim(num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048) -> float: + return (dim * math.log(max_position_embeddings / + (num_rotations * 2 * math.pi))) / (2 * + math.log(base)) + + +# Find dim range bounds based on rotations +def _yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048) -> tuple[int, int]: + low = math.floor( + _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil( + _yarn_find_correction_dim(high_rot, dim, base, + max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask(low: float, high: float, dim: int, + dtype: torch.dtype) -> torch.Tensor: + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def _yarn_get_mscale(scale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + +class YaRNScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation + self.mscale = float( + _yarn_get_mscale(self.scaling_factor) * attn_factor) + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base**( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / + self.rotary_dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, + self.rotary_dim, self.base, + self.max_position_embeddings) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = (1 - _yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, + dtype=torch.float)) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * ( + 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange(self.max_position_embeddings * self.scaling_factor, + dtype=torch.float32) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = (freqs.cos() * self.mscale) + sin = (freqs.sin() * self.mscale) + cache = torch.cat((cos, sin), dim=-1) + return cache + + +class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): + """Phi3 family of models scaled rotary embedding. + + Based on the original RotaryEmbedding implementation. + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + original_max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + short_factor: list[float], + long_factor: list[float], + short_mscale: Optional[float] = None, + long_mscale: Optional[float] = None, + ): + super().__init__() + + if is_neox_style is False: + raise ValueError( + "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style." + ) + + self.rotary_dim = rotary_dim + self.head_size = head_size + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.base = base + self.short_factor = short_factor + self.long_factor = long_factor + + scale = self.max_position_embeddings / \ + self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = math.sqrt( + 1 + math.log(scale) / + math.log(self.original_max_position_embeddings)) + if short_mscale is None: + short_mscale = scaling_factor + if long_mscale is None: + long_mscale = scaling_factor + + self.short_mscale = short_mscale + self.long_mscale = long_mscale + + short_cache = self._compute_cos_sin_cache( + original_max_position_embeddings, short_factor, short_mscale) + short_cache = short_cache.to(dtype) + + long_cache = self._compute_cos_sin_cache(max_position_embeddings, + long_factor, long_mscale) + long_cache = long_cache.to(dtype) + + long_short_cache = torch.cat([short_cache, long_cache], dim=0) + self.register_buffer("long_short_cos_sin_cache", + long_short_cache, + persistent=False) + + def _compute_inv_freq(self, rescale_factors: list[float]) -> torch.Tensor: + rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32) + inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))) + return inv_freq + + def _compute_cos_sin_cache( + self, + max_position_embeddings: int, + rescale_factors: list[float], + mscale: float, + ) -> torch.Tensor: + inv_freq = self._compute_inv_freq(rescale_factors) + t = torch.arange(max_position_embeddings, dtype=torch.float) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * mscale + sin = freqs.sin() * mscale + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + assert key is not None + query = query.view(*query.shape[:-1], -1, self.head_size) + key = key.view(*key.shape[:-1], -1, self.head_size) + + k = self.original_max_position_embeddings + long_prompt_offset = (torch.any(positions > k).float() * + torch.full_like(positions, k)).long() + idx = (torch.add(positions, long_prompt_offset) + if long_prompt_offset is not None else positions) + idx = torch.add(idx, offsets) if offsets is not None else idx + cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx) + + cos, sin = cos_sin.chunk(2, dim=-1) + cos = cos.repeat(1, 2).unsqueeze(-2) + sin = sin.repeat(1, 2).unsqueeze(-2) + + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = query_rot * cos + _rotate_neox(query_rot) * sin + query = torch.cat((query_rot, query_pass), dim=-1) + + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = key_rot * cos + _rotate_neox(key_rot) * sin + key = torch.cat((key_rot, key_pass), dim=-1) + + return query.flatten(-2), key.flatten(-2) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +@triton.jit +def deepseek_scaling_rotary_emb_kernel_gptj(cos_sin, q, stride1: int, + stride2: int, stride_cs: int, + dim1: int, dim2: int, dim3: int, + BLOCK_SIZE: tl.constexpr): + pid0 = tl.program_id(0) + pid1 = tl.program_id(1) + pid2 = tl.program_id(2) + offsets_cs = tl.arange(0, BLOCK_SIZE) + pid2 * BLOCK_SIZE + offsets_q = tl.arange(0, BLOCK_SIZE * 2) + pid2 * BLOCK_SIZE * 2 + + offsets = pid0 * stride1 + pid1 * stride2 + offsets_q + mask = offsets_cs < dim3 + mask2 = offsets_q < dim3 * 2 + + v_cos = tl.load(cos_sin + pid0 * stride_cs + offsets_cs, mask=mask) + v_cos2 = tl.interleave(v_cos, v_cos) + v_sin = tl.load(cos_sin + pid0 * stride_cs + dim3 + offsets_cs, mask=mask) + v_sin2 = tl.interleave(v_sin, v_sin) + x12 = tl.load(q + offsets, mask=mask2) + x1, x2 = tl.split(x12.reshape([BLOCK_SIZE, 2])) + # we are both reading and writing 'q'; make sure all warps are in sync + tl.debug_barrier() + x12_ = tl.ravel(tl.join(-x2, x1)) + x12 = x12 * v_cos2 + x12_ * v_sin2 + tl.store(q + offsets, x12, mask=mask2) + + +class DeepseekScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + reference: bool = False, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.reference = reference + # Get n-d magnitude scaling corrected for interpolation. + self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) / + yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * + attn_factor) + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base**( + torch.arange(0, + self.rotary_dim, + 2, + dtype=torch.float, + device=current_platform.device_type) / + self.rotary_dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, + self.rotary_dim, self.base, + self.max_position_embeddings) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = (1 - _yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, + dtype=torch.float)) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * ( + 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange(self.max_position_embeddings * self.scaling_factor, + device=current_platform.device_type, + dtype=torch.float32) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = (freqs.cos() * self.mscale) + sin = (freqs.sin() * self.mscale) + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """PyTorch-native implementation equivalent to forward().""" + assert key is not None + + if self.cos_sin_cache.device != positions.device: + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( + positions.device) + cos_sin = self.cos_sin_cache[torch.add(positions, offsets) + if offsets is not None else positions] + if query.device.type == 'cuda' and not self.is_neox_style \ + and not self.reference: + assert len(query.shape) == 3 + + def call(q): + BLOCK_SIZE = 64 + grid = ( + q.shape[-3], + q.shape[-2], + triton.cdiv(self.rotary_dim // 2, BLOCK_SIZE), + ) + deepseek_scaling_rotary_emb_kernel_gptj[grid]( + cos_sin, + q, + stride1=q.stride()[-3], + stride2=q.stride()[-2], + stride_cs=cos_sin.stride()[-2], + dim1=q.shape[0], + dim2=q.shape[1], + dim3=self.rotary_dim // 2, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=1) + + call(query) + call(key) + return query, key + else: + query_rot = query[..., :self.rotary_dim] + key_rot = key[..., :self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim:] + key_pass = key[..., self.rotary_dim:] + + cos, sin = cos_sin.chunk(2, dim=-1) + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. + cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + query_rot = query_rot * cos + rotate_fn(query_rot) * sin + key_rot = key_rot * cos + rotate_fn(key_rot) * sin + + + if self.rotary_dim < self.head_size: + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + else: + query = query_rot + key = key_rot + return query, key + + +class Llama3RotaryEmbedding(RotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + scaling_factor: float, + low_freq_factor: float, + high_freq_factor: float, + orig_max_position: int, + ) -> None: + self.scaling_factor = scaling_factor + self.low_freq_factor = low_freq_factor + self.high_freq_factor = high_freq_factor + self.orig_max_position = orig_max_position + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, base: float) -> torch.Tensor: + inv_freqs = super()._compute_inv_freq(base) + low_freq_wavelen = self.orig_max_position / self.low_freq_factor + high_freq_wavelen = self.orig_max_position / self.high_freq_factor + + wave_len = 2 * math.pi / inv_freqs + if self.low_freq_factor != self.high_freq_factor: + smooth = (self.orig_max_position / wave_len - self.low_freq_factor + ) / (self.high_freq_factor - self.low_freq_factor) + else: + smooth = 0 + new_freqs = torch.where( + wave_len < high_freq_wavelen, + inv_freqs, + torch.where( + wave_len > low_freq_wavelen, + inv_freqs / self.scaling_factor, + (1 - smooth) * inv_freqs / self.scaling_factor + + smooth * inv_freqs, + ), + ) + return new_freqs + + +class Llama4VisionRotaryEmbedding(RotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + ): + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, base: float) -> torch.Tensor: + inv_freqs = super()._compute_inv_freq(base) + inv_freqs = inv_freqs[:(self.rotary_dim // 2)] + return inv_freqs + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.base) + + # self.max_position_embeddings here is number of image patches + # i.e. (image_size // patch_size) ** 2 + num_patches = self.max_position_embeddings + img_idx = torch.arange(num_patches, + dtype=torch.int32) \ + .reshape(num_patches, 1) + img_idx = torch.cat([img_idx, img_idx[:1]], dim=0) + img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN + num_patches_single_dim = int(math.sqrt(num_patches)) + frequencies_x = img_idx % num_patches_single_dim + frequencies_y = img_idx // num_patches_single_dim + freqs_x = ((frequencies_x + 1)[..., None] * + inv_freq[None, None, :]).repeat_interleave(2, dim=-1) + freqs_y = ((frequencies_y + 1)[..., None] * + inv_freq[None, None, :]).repeat_interleave(2, dim=-1) + freqs = torch.cat([freqs_x, freqs_y], + dim=-1).float().contiguous()[..., ::2] + freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0) + cache = torch.view_as_complex( + torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)) + return cache + + def forward( + self, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + assert key is not None + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device) + query_ = torch.view_as_complex(query.float().reshape( + *query.shape[:-1], -1, 2)) + key_ = torch.view_as_complex(key.float().reshape( + *key.shape[:-1], -1, 2)) + broadcast_shape = [ + d if i == 1 or i == (query_.ndim - 1) else 1 + for i, d in enumerate(query_.shape) + ] + freqs_ci = self.cos_sin_cache.view(*broadcast_shape) + query_out = torch.view_as_real(query_ * freqs_ci).flatten(3) + key_out = torch.view_as_real(key_ * freqs_ci).flatten(3) + return query_out.type_as(query), key_out.type_as(key) + + +class MRotaryEmbedding(RotaryEmbedding): + """Rotary Embedding with Multimodal Sections.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + mrope_section: Optional[list[int]] = None, + ) -> None: + # In Qwen2.5-VL, the maximum index value is related to the duration of + # the input video. We enlarge max_position_embeddings to 4 times to get + # a larger the cos and sin cache. + self.cache_max_position_num = max_position_embeddings * 4 + super().__init__(head_size, rotary_dim, self.cache_max_position_num, + base, is_neox_style, dtype) + + self.mrope_section = mrope_section + if self.mrope_section: + assert sum(self.mrope_section) == rotary_dim // 2 + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """PyTorch-native implementation equivalent to forward(). + + Args: + positions: + [num_tokens,] (text only) or + [3, num_tokens] (T/H/W positions with multimodal inputs) + query: [num_tokens, num_heads * head_size] + key: [num_tokens, num_kv_heads * head_size] + """ + assert positions.ndim == 1 or positions.ndim == 2 + assert key is not None + + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if positions.ndim == 2: + assert self.mrope_section + + cos = torch.cat([ + m[i] + for i, m in enumerate(cos.split(self.mrope_section, dim=-1)) + ], + dim=-1) + sin = torch.cat([ + m[i] + for i, m in enumerate(sin.split(self.mrope_section, dim=-1)) + ], + dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + @classmethod + def get_input_positions( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], + video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]], + second_per_grid_ts: Optional[list[float]], + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[list[list[int]], int]: + """Get mrope input positions and delta value.""" + + image_grid_thw = [] if image_grid_thw is None else image_grid_thw + video_grid_thw = [] if video_grid_thw is None else video_grid_thw + second_per_grid_ts = [] if second_per_grid_ts is None else \ + second_per_grid_ts + + llm_positions, mrope_position_delta = \ + cls.get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + + return llm_positions.tolist(), mrope_position_delta + + @classmethod + def get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + second_per_grid_ts: list[float], + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + from vllm.transformers_utils.config import thinker_uses_mrope + if thinker_uses_mrope(hf_config): + return cls._omni_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) + elif "glm4v" in hf_config.model_type: + return cls._glm4v_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + context_len=context_len, + seq_len=seq_len, + ) + else: + return cls._vl_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + ) + + @classmethod + def _glm4v_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + context_len: int = 0, + seq_len: Optional[int] = None, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value for GLM4V.""" + + image_token_id = hf_config.image_token_id + video_start_token_id = hf_config.video_start_token_id + video_end_token_id = hf_config.video_end_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + llm_pos_ids_list: list = [] + + if not (image_grid_thw is None and video_grid_thw is None): + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + + input_token_type: list[str] = [] + video_check_flg = False + for token in input_tokens: + if token == video_start_token_id: + video_check_flg = True + elif token == video_end_token_id: + video_check_flg = False + + if (token == image_token_id) and (video_check_flg is False): + input_token_type.append("image") + elif (token == image_token_id) and (video_check_flg is True): + input_token_type.append("video") + else: + input_token_type.append("text") + + input_type_group: list[tuple[str, int, int]] = [] + for key, group_iter in itertools.groupby( + enumerate(input_token_type), lambda x: x[1]): + group_list = list(group_iter) + start_index = group_list[0][0] + end_index = group_list[-1][0] + 1 + input_type_group.append((key, start_index, end_index)) + + video_frame_num = 1 + mm_data_idx = 0 + for modality_type, start_idx, end_idx in input_type_group: + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + if modality_type == "image": + t, h, w = ( + image_grid_thw[mm_data_idx][0], + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_merge_size, w // spatial_merge_size + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( + llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( + llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx) + mm_data_idx += 1 + + elif modality_type == "video": + t, h, w = ( + video_frame_num, + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_merge_size, w // spatial_merge_size + + for t_idx in range(llm_grid_t): + t_index = torch.tensor(t_idx).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view( + 1, -1, 1).expand(1, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view( + 1, 1, -1).expand(1, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx) + + mm_data_idx += 1 + video_frame_num += 1 + + else: + text_len = end_idx - start_idx + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + + st_idx) + video_frame_num = 1 + + else: + text_len = len(input_tokens) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1)) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = llm_positions[:, context_len:seq_len] + mrope_position_delta = (llm_positions.max() + 1 - + len(input_tokens)).item() + return llm_positions, mrope_position_delta + + @classmethod + def _vl_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + second_per_grid_ts: list[float], + context_len: int = 0, + seq_len: Optional[int] = None, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value.""" + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + tokens_per_second = getattr(hf_config.vision_config, + "tokens_per_second", 1.0) + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + video_second_per_grid_t = 0.0 + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_second_per_grid_t = 1.0 + if second_per_grid_ts: + video_second_per_grid_t = second_per_grid_ts[video_index] + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_merge_size, w // spatial_merge_size + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = (torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t * + tokens_per_second).long().flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( + llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( + llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - + len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + + @classmethod + def _omni_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + second_per_grid_ts: Optional[list[float]] = None, + context_len: int = 0, + seq_len: Optional[int] = None, + audio_feature_lengths: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value (Qwen2.5-Omni version). + + Differences from MRotaryEmbedding: + 1. Add audio support (and related `audio_feature_lengths`). + 2. Add `use_audio_in_video` option to read audio from video inputs. + In this case, audio and vision position ids will be split into + chunks and interleaved. + + Example: + + (V_i are vision position ids, A_i are audio position ids) + + |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|... + |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |... + """ + + # TODO(fyabc): refactor and share more code with + # _vl_get_input_positions_tensor. + + thinker_config = hf_config.thinker_config + audio_token_id = thinker_config.audio_token_index + image_token_id = thinker_config.image_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + vision_start_token_id = thinker_config.vision_start_token_id + vision_end_token_id = thinker_config.vision_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr(thinker_config.vision_config, + "tokens_per_second", 25) + + if isinstance(image_grid_thw, list): + image_grid_thw = torch.tensor(image_grid_thw) + if isinstance(video_grid_thw, list): + video_grid_thw = torch.tensor(video_grid_thw) + + src_item = input_tokens + audio_seqlens = audio_feature_lengths + if not second_per_grid_ts: + second_per_grid_ts = [1] * video_grid_thw.shape[0] + audio_idx = 0 + video_idx = 0 + image_idx = 0 + new_src_item: list[int] = [] + llm_pos_ids_list: list[torch.Tensor] = [] + + idx = 0 + while idx < len(src_item): + new_src_item_len = len(new_src_item) + start_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + if src_item[idx] not in [ + audio_token_id, video_token_id, image_token_id + ]: + if use_audio_in_video and idx > 0: + if src_item[idx] == vision_end_token_id and \ + src_item[idx - 1] == audio_end_token_id: + # processing the <|audio_eos|> before <|vision_eos|> + start_idx -= 1 + elif src_item[idx] == audio_start_token_id and \ + src_item[idx - 1] == vision_start_token_id: + # processing the <|audio_bos|> after <|vision_eos|> + start_idx -= 1 + new_src_item.append(src_item[idx]) + llm_pos_ids = torch.tensor([start_idx], + dtype=torch.long).expand(3, -1) + llm_pos_ids_list.append(llm_pos_ids) + elif src_item[idx] == audio_token_id: + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + new_src_item.extend([audio_token_id] * place_num) + llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx + llm_pos_ids_list.append(llm_pos_ids) + audio_idx += 1 + elif src_item[idx] == image_token_id: + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long() + llm_pos_ids = cls._get_llm_pos_ids_for_vision( + start_idx, image_idx, spatial_merge_size, t_index, grid_hs, + grid_ws) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = image_grid_thw[image_idx].prod() // ( + spatial_merge_size**2) + new_src_item.extend([image_token_id] * vision_seqlen) + image_idx += 1 + elif src_item[idx] == video_token_id and not use_audio_in_video: + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * + second_per_grid_ts[video_idx] * + tokens_per_second).long() + llm_pos_ids = cls._get_llm_pos_ids_for_vision( + start_idx, video_idx, spatial_merge_size, t_index, grid_hs, + grid_ws) + llm_pos_ids_list.append(llm_pos_ids) + vision_seqlen = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2) + new_src_item.extend([video_token_id] * vision_seqlen) + video_idx += 1 + else: + # read audio from video + assert audio_seqlens is not None + audio_seqlen = audio_seqlens[audio_idx] + vision_seqlen = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2) + grid_t = video_grid_thw[video_idx][0] + grid_h = video_grid_thw[video_idx][1] + grid_w = video_grid_thw[video_idx][2] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = (torch.arange(grid_t) * + second_per_grid_ts[video_idx] * + tokens_per_second).long() + t_index_split_chunk = cls._split_list_into_ranges( + t_index, t_ntoken_per_chunk) + place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 + pure_audio_len = place_num - 2 + added_audio_len = 0 + audio_llm_pos_ids_list: list[torch.Tensor] = [] + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = len( + t_chunk) * grid_h * grid_w // (spatial_merge_size**2) + new_src_item.extend([video_token_id] * + vision_ntoken_per_chunk) + vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision( + start_idx, video_idx, spatial_merge_size, t_chunk, + grid_hs, grid_ws).split(1, dim=1) + llm_pos_ids_list.extend(vision_llm_pos_ids_list) + new_src_item.extend( + min(t_ntoken_per_chunk, pure_audio_len - + added_audio_len) * [audio_token_id]) + audio_start_idx = start_idx if len( + audio_llm_pos_ids_list + ) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1 + if min(t_ntoken_per_chunk, + pure_audio_len - added_audio_len) > 0: + audio_llm_pos_ids_list = (torch.arange( + min(t_ntoken_per_chunk, pure_audio_len - + added_audio_len)).expand(3, -1) + + audio_start_idx).split(1, + dim=1) + else: + audio_llm_pos_ids_list = [] + added_audio_len += min(t_ntoken_per_chunk, + pure_audio_len - added_audio_len) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + if added_audio_len < pure_audio_len: + new_src_item.extend( + (pure_audio_len - added_audio_len) * [audio_token_id]) + audio_llm_pos_ids_list = ( + torch.arange(pure_audio_len - added_audio_len).expand( + 3, -1) + llm_pos_ids_list[-1].max() + 1).split( + 1, dim=1) + llm_pos_ids_list.extend(audio_llm_pos_ids_list) + audio_idx += 1 + video_idx += 1 + # move to the next token + idx += len(new_src_item) - new_src_item_len + + llm_positions = torch.cat(llm_pos_ids_list, dim=1) + mrope_position_delta = torch.cat(llm_pos_ids_list, + dim=1).max() + 1 - len(src_item) + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + + @staticmethod + def _get_llm_pos_ids_for_vision( + start_idx: int, + vision_idx: int, + spatial_merge_size: int, + t_index: list[int], + grid_hs: torch.Tensor, + grid_ws: torch.Tensor, + ) -> torch.Tensor: + llm_pos_ids_list = [] + llm_grid_h = grid_hs[vision_idx] // spatial_merge_size + llm_grid_w = grid_ws[vision_idx] // spatial_merge_size + h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand( + len(t_index), -1, llm_grid_w).flatten()) + w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand( + len(t_index), llm_grid_h, -1).flatten()) + t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view( + -1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten() + _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index]) + llm_pos_ids_list.append(_llm_pos_ids + start_idx) + llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) + return llm_pos_ids + + @staticmethod + def _split_list_into_ranges(lst: torch.Tensor, + interval: int) -> list[list[int]]: + ranges: list[list[int]] = [[] + for _ in range((max(lst) // interval) + 1)] + for num in lst: + index = num // interval + ranges[index].append(num) + return ranges + + @staticmethod + def get_next_input_positions( + mrope_position_delta: int, + context_len: int, + seq_len: int, + ) -> list[list[int]]: + return [ + list( + range(context_len + mrope_position_delta, + seq_len + mrope_position_delta)) for _ in range(3) + ] + + @staticmethod + def get_next_input_positions_tensor(out: np.ndarray, out_offset: int, + mrope_position_delta: int, + context_len: int, num_new_tokens: int): + + values = np.arange(mrope_position_delta + context_len, + mrope_position_delta + context_len + num_new_tokens, + dtype=out.dtype) + out[:, out_offset:out_offset + num_new_tokens] = values + + @classmethod + def omni_get_updates_use_audio_in_video( + cls, + thinker_config: PretrainedConfig, + audio_len: int, + video_grid_thw: Union[list[int], torch.Tensor], + video_second_per_grid_t: float, + ) -> list[int]: + """Get video prompt updates when `use_audio_in_video` is True. + + In this case, audio and vision update ids will be split into + chunks and interleaved (details in `_omni_get_input_positions_tensor`). + + <|video_bos|><|VIDEO|><|video_eos|> => + <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|> + """ + + audio_token_id = thinker_config.audio_token_index + video_token_id = thinker_config.video_token_index + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + seconds_per_chunk = thinker_config.seconds_per_chunk + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + tokens_per_second = getattr(thinker_config.vision_config, + "tokens_per_second", 25) + + grid_t = video_grid_thw[0] + grid_h = video_grid_thw[1] + grid_w = video_grid_thw[2] + t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) + t_index = (torch.arange(grid_t) * video_second_per_grid_t * + tokens_per_second).long() + t_index_split_chunk = cls._split_list_into_ranges( + t_index, t_ntoken_per_chunk) + + updates = [audio_start_token_id] + added_audio_len = 0 + for t_chunk in t_index_split_chunk: + vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // ( + spatial_merge_size**2) + updates.extend([video_token_id] * vision_ntoken_per_chunk) + + audio_chunk_size = min(t_ntoken_per_chunk, + audio_len - added_audio_len) + updates.extend(audio_chunk_size * [audio_token_id]) + added_audio_len += audio_chunk_size + if added_audio_len < audio_len: + updates.extend((audio_len - added_audio_len) * [audio_token_id]) + updates.extend([audio_end_token_id]) + + return updates + + +@CustomOp.register("dual_chunk_rotary_embedding") +class DualChunkRotaryEmbedding(CustomOp): + """Rotary positional embedding for Dual Chunk Attention.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + chunk_size: int, + local_size: int, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.chunk_size = chunk_size + self.local_size = local_size + self.dtype = dtype + self.device = torch.device(f"cuda:{torch.cuda.current_device()}") + (q_cache, qc_cache, k_cache, qc_no_clamp_cache, + q_inter_cache) = self._compute_cos_sin_cache() + + self.register_buffer("cos_sin_q_cache", q_cache, persistent=False) + self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False) + self.register_buffer("cos_sin_k_cache", k_cache, persistent=False) + self.register_buffer("cos_sin_qc_no_clamp_cache", + qc_no_clamp_cache, + persistent=False) + self.register_buffer("cos_sin_q_inter_cache", + q_inter_cache, + persistent=False) + + def _compute_inv_freq(self, base: float) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`. + # However, we use `torch.arange(..., dtype=torch.float)` instead to + # avoid numerical issues with large base values (e.g., 10000000). + # This may cause a slight numerical difference between the HF + # implementation and ours. + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / (base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + chunk_len = self.chunk_size - self.local_size + q_t = torch.arange(chunk_len, dtype=torch.float) + qc_t = (torch.arange(chunk_len, dtype=torch.float) + + chunk_len).clamp(max=self.chunk_size) + k_t = torch.arange(self.max_position_embeddings, + dtype=torch.float) % chunk_len + + # count from chunk_len, no clamp(self.chunk_size) restriction + qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len + # count from self.chunk_size for q_inter's rope + q_inter_t = torch.arange(chunk_len, + dtype=torch.float) + self.chunk_size + + q_freqs = torch.outer(q_t, inv_freq) + qc_freqs = torch.outer(qc_t, inv_freq) + k_freqs = torch.outer(k_t, inv_freq) + qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq) + q_inter_freqs = torch.outer(q_inter_t, inv_freq) + + q_cos = q_freqs.cos() + q_sin = q_freqs.sin() + qc_cos = qc_freqs.cos() + qc_sin = qc_freqs.sin() + k_cos = k_freqs.cos() + k_sin = k_freqs.sin() + + qc_no_clamp_cos = qc_no_clamp_freqs.cos() + qc_no_clamp_sin = qc_no_clamp_freqs.sin() + q_inter_cos = q_inter_freqs.cos() + q_inter_sin = q_inter_freqs.sin() + + q_cache = torch.cat((q_cos, q_sin), dim=-1).to(dtype=self.dtype, + device=self.device) + qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(dtype=self.dtype, + device=self.device) + k_cache = torch.cat((k_cos, k_sin), dim=-1).to(dtype=self.dtype, + device=self.device) + qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), + dim=-1).to(dtype=self.dtype, + device=self.device) + q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), + dim=-1).to(dtype=self.dtype, + device=self.device) + return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + query = query.view(*query.shape[:-1], -1, self.head_size) + key = key.view(*key.shape[:-1], -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + key_rot = key[..., :self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim:] + key_pass = key[..., self.rotary_dim:] + else: + query_pass = None + key_pass = None + + positions_with_offsets = (torch.add(positions, offsets) + if offsets is not None else positions) + key = self._apply_rotary_embedding( + self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass) + chunk_len = self.chunk_size - self.local_size + query = self._apply_rotary_embedding( + self.cos_sin_q_cache[positions_with_offsets % chunk_len], + query_rot, query_pass) + query_succ = self._apply_rotary_embedding( + self.cos_sin_qc_cache[positions_with_offsets % chunk_len], + query_rot, query_pass) + query_inter = self._apply_rotary_embedding( + self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1), + query_rot, query_pass) + query_succ_critical = self._apply_rotary_embedding( + self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len], + query_rot, query_pass) + query_inter_critical = self._apply_rotary_embedding( + self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len], + query_rot, query_pass) + + # merge query into one tensor to simplify the interfaces + query = torch.cat(( + query, + query_succ, + query_inter, + query_succ_critical, + query_inter_critical, + ), + dim=-1) + return query, key + + def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass): + cos, sin = cos_sin.chunk(2, dim=-1) + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. + cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin + + if self.rotary_dim < self.head_size: + hidden = torch.cat((hidden_rot, hidden_pass), dim=-1) + else: + hidden = hidden_rot + return hidden.flatten(-2).squeeze(0) + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + s += f", chunk_size={self.chunk_size}, local_size={self.local_size}" + return s + + +_ROPE_DICT: dict[tuple, RotaryEmbedding] = {} + + +def get_rope( + head_size: int, + rotary_dim: int, + max_position: int, + base: float, + is_neox_style: bool = True, + rope_scaling: Optional[dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, + partial_rotary_factor: float = 1.0, + dual_chunk_attention_config: Optional[dict[str, Any]] = None, +) -> RotaryEmbedding: + if dtype is None: + dtype = torch.get_default_dtype() + if rope_scaling is not None: + # Transforms every value that is a list into a tuple for caching calls + rope_scaling_tuple = { + k: tuple(v) if isinstance(v, list) else v + for k, v in rope_scaling.items() + } + rope_scaling_args = tuple(rope_scaling_tuple.items()) + else: + rope_scaling_args = None + + if dual_chunk_attention_config is not None: + dual_chunk_attention_tuple = { + k: tuple(v) if isinstance(v, list) else v + for k, v in dual_chunk_attention_config.items() + if k != "sparse_attention_config" + } + dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items()) + else: + dual_chunk_attention_args = None + + if partial_rotary_factor < 1.0: + rotary_dim = int(rotary_dim * partial_rotary_factor) + key = (head_size, rotary_dim, max_position, base, is_neox_style, + rope_scaling_args, dual_chunk_attention_args, dtype) + if key in _ROPE_DICT: + return _ROPE_DICT[key] + + if dual_chunk_attention_config is not None: + extra_kwargs = { + k: v + for k, v in dual_chunk_attention_config.items() + if k in ("chunk_size", "local_size") + } + rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, dtype, + **extra_kwargs) + elif not rope_scaling: + rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, + is_neox_style, dtype) + else: + scaling_type = rope_scaling["rope_type"] + + if scaling_type == "llama3": + scaling_factor = rope_scaling["factor"] + low_freq_factor = rope_scaling["low_freq_factor"] + high_freq_factor = rope_scaling["high_freq_factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, dtype, + scaling_factor, low_freq_factor, + high_freq_factor, + original_max_position) + elif scaling_type == "mllama4": + rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, dtype) + elif scaling_type == "default": + if "mrope_section" in rope_scaling: + rotary_emb = MRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + mrope_section=rope_scaling["mrope_section"], + ) + else: + rotary_emb = RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + ) + elif scaling_type == "linear": + scaling_factor = rope_scaling["factor"] + rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, + scaling_factor, dtype) + elif scaling_type == "ntk": + scaling_factor = rope_scaling["factor"] + mixed_b = rope_scaling.get('mixed_b', None) + rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, + scaling_factor, dtype, + mixed_b) + elif scaling_type == "dynamic": + if "alpha" in rope_scaling: + scaling_alpha = rope_scaling["alpha"] + rotary_emb = DynamicNTKAlphaRotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, + scaling_alpha, dtype) + elif "factor" in rope_scaling: + scaling_factor = rope_scaling["factor"] + rotary_emb = DynamicNTKScalingRotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, + scaling_factor, dtype) + else: + raise ValueError("Dynamic rope scaling must contain either " + "'alpha' or 'factor' field") + elif scaling_type == "yarn": + scaling_factor = rope_scaling["factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("extrapolation_factor", "attn_factor", "beta_fast", + "beta_slow") + } + rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim, + original_max_position, + base, is_neox_style, + scaling_factor, dtype, + **extra_kwargs) + elif scaling_type == "deepseek_yarn": + scaling_factor = rope_scaling["factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + # assert max_position == original_max_position * scaling_factor + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("extrapolation_factor", "attn_factor", "beta_fast", + "beta_slow", "mscale", "mscale_all_dim") + } + rotary_emb = DeepseekScalingRotaryEmbedding( + head_size, rotary_dim, original_max_position, base, + is_neox_style, scaling_factor, dtype, **extra_kwargs) + elif scaling_type == "longrope": + short_factor = rope_scaling["short_factor"] + long_factor = rope_scaling["long_factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("short_mscale", "long_mscale") + } + rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( + head_size, rotary_dim, max_position, original_max_position, + base, is_neox_style, dtype, short_factor, long_factor, + **extra_kwargs) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + _ROPE_DICT[key] = rotary_emb + return rotary_emb diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py new file mode 100644 index 0000000..bcc3606 --- /dev/null +++ b/vllm/model_executor/layers/sampler.py @@ -0,0 +1,1224 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""A layer that samples the next tokens from the model's outputs.""" +import os +import itertools +from collections.abc import Iterator +from dataclasses import dataclass +from importlib.util import find_spec +from math import inf +from typing import Optional, Union + +import msgspec +import torch +import torch.nn as nn + +import vllm.envs as envs +from vllm.model_executor.layers.utils import apply_penalties +from vllm.model_executor.sampling_metadata import (SamplingMetadata, + SamplingTensors, + SequenceGroupToSample) +from vllm.sampling_params import SamplingType +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, + CompletionSequenceGroupOutput, Logprob, + PromptLogprobs, SampleLogprobs, SequenceOutput) +from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics + +if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): + # yapf: disable + from flashinfer.sampling import ( + top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling) + + # yapf: enable +else: + flashinfer_top_k_top_p_sampling = None + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def get_sampler() -> torch.nn.Module: + if envs.VLLM_USE_V1: + # Lazy import: the v1 package isn't distributed + from vllm.v1.sample.sampler import Sampler as V1Sampler + return V1Sampler() + if envs.VLLM_ZERO_OVERHEAD: + from vllm.zero_overhead.sampler import ZeroOverheadSampler + return ZeroOverheadSampler() + return Sampler() + + +# (num_token_ids, num_parent_ids) per sequence group. +SampleResultType = list[tuple[list[int], list[int]]] + +# Types of temporary data structures used for +# computing sample_result +SampleMetadataType = dict[SamplingType, tuple[list[int], + list[SequenceGroupToSample]]] +MultinomialSamplesType = dict[SamplingType, torch.Tensor] +SampleResultsDictType = dict[int, tuple[list[int], list[int]]] + + +# Encapsulates temporary data structures for computing +# sample_result. +# +# * For multi-step scheduling: must be returned +# by `Sampler.forward()` and used later to compute the pythonized +# sample_result +# +# * For single-step scheduling: consumed immediately +# inside `Sampler.forward()` to compute pythonized sample_result. +@dataclass +class SampleResultArgsType: + sample_metadata: SampleMetadataType + multinomial_samples: MultinomialSamplesType + sample_results_dict: SampleResultsDictType + sampling_metadata: SamplingMetadata + greedy_samples: Optional[torch.Tensor] + + +# Union of non-deferred (single-step scheduling) +# vs deferred (multi-step scheduling) +# sample result types +MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType] + +# Abbreviation of the _sample() return type +SampleReturnType = tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]] + + +class SamplerOutput( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] + """For each sequence group, we generate a list of SequenceOutput object, + each of which contains one possible candidate for the next token. + + This data structure implements methods, so it can be used like a list, but + also has optional fields for device tensors. + """ + + outputs: list[CompletionSequenceGroupOutput] + + # On-device tensor containing probabilities of each token. + sampled_token_probs: Optional[torch.Tensor] = None + + # On-device tensor containing the logprobs of each token. + logprobs: Optional["torch.Tensor"] = None + + # Holds either (1) the pythonized sampler result (single-step scheduling) + # or (2) what will be arguments for later deferred pythonization of the + # sampler result (muliti-step scheduling) + deferred_sample_results_args: Optional[SampleResultArgsType] = None + + # On-device tensor containing the sampled token ids. + sampled_token_ids: Optional[torch.Tensor] = None + # CPU tensor containing the sampled token ids. Used during multi-step to + # return the sampled token ids from last rank to AsyncLLMEngine to be + # 'broadcasted' to all other PP ranks for next step. + sampled_token_ids_cpu: Optional[torch.Tensor] = None + + # On-device tensor containing the sampled token embeddings (embeddings + # corresponding to the sampled token ids). Used when prompt embeddings are + # specified in lieu of prompt token ids or text. + sampled_token_embeds: Optional[torch.Tensor] = None + + # Spec decode metrics populated by workers. + spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None + + # Optional last hidden states from the model. + hidden_states: Optional[torch.Tensor] = None + + # Optional prefill hidden states from the model + # (used for models like EAGLE). + prefill_hidden_states: Optional[torch.Tensor] = None + + # Time taken in the forward pass for this across all workers + model_forward_time: Optional[float] = None + + # Time taken in the model execute function. This will include model forward, + # block/sync across workers, cpu-gpu sync time and sampling time. + model_execute_time: Optional[float] = None + + # Optional lm_head logits from the model. + logits: Optional[torch.Tensor] = None + + # tree-style cartesian candidates + cart_candidates: Optional[torch.Tensor] = None + + # tree-style cartesian candidates + tree_attn_masks: Optional[torch.Tensor] = None + + + def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput: + return self.outputs[idx] + + def __setitem__(self, idx: int, value): + self.outputs[idx] = value + + def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]: + return iter(self.outputs) + + def __len__(self): + return len(self.outputs) + + def __eq__(self, other: object): + return isinstance(other, + self.__class__) and self.outputs == other.outputs + + def __repr__(self) -> str: + """Show the shape of a tensor instead of its values to reduce noise. + """ + sampled_token_probs_repr = ("None" if self.sampled_token_probs is None + else self.sampled_token_probs.shape) + sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else + self.sampled_token_ids.shape) + return ( + f"SamplerOutput(outputs={self.outputs}, " + f"sampled_token_probs={sampled_token_probs_repr}, " + f"sampled_token_ids={sampled_token_ids_repr}, " + f"spec_decode_worker_metrics={self.spec_decode_worker_metrics}, " + f"logits={self.logits}, " + f"tree_attn_masks={self.tree_attn_masks})") + + +class Sampler(nn.Module): + """Samples the next tokens from the model's outputs. + + This layer does the following: + 1. Discard the hidden states that are not used for sampling (i.e., all + tokens except the final one in each prompt). + 2. Compute the logits for the next tokens. + 3. Apply presence, frequency and repetition penalties. + 4. Apply temperature scaling. + 5. Apply top-p and top-k truncation. + 6. Sample the next tokens. + Here, each sequence group within the batch can have different sampling + parameters (e.g., sampling method, temperature, top-p, top-k, etc.). + + The structure of the logits tensor is coupled with the seq_groups in + sampling_metadata. Typically, each sequence in each seq_group has one row in + logits for the next token to be sampled; however, for a seq_group with a + prompt request with the prompt_logprobs sampling parameter, there are rows + in logits for each token in the input prompt. + """ + + def __init__(self): + super().__init__() + + # Whether or not the SamplerOutput should have on-device tensors + # containing the sampled token ids and probabilities. This is used by + # speculative decoding and when prompt embeddings are specified. + self.include_gpu_probs_tensor = False + self.should_modify_greedy_probs_inplace = False + self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1') + + def _init_sampling_tensors( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ): + """The goal here is to reuse sampling tensors between similar decode + runs. This is possible because sampling logic does not change between + decodes of the same sequences. + """ + _, vocab_size = logits.shape + + # First free any existing stored sampling tensors. + # This is necessary because some sampling tensors may + # have pinned memory. + self._sampling_tensors = None + + # Initialize new sampling tensors + (sampling_tensors, do_penalties, do_top_p_top_k, + do_min_p) = SamplingTensors.from_sampling_metadata( + sampling_metadata, vocab_size, logits.device, logits.dtype) + + self._sampling_tensors = sampling_tensors + self._do_penalties = do_penalties + self._do_top_p_top_k = do_top_p_top_k + self._do_min_p = do_min_p + + def forward( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + """ + Single-step scheduling: + * Perform GPU-side sampling computation & compute + GPU-side logprobs tensor + * Pythonize sampling result & logprobs tensor + + Multi-step scheduling: + * Perform GPU-side sampling computation & compute + GPU-side logprobs tensor + * Defer Pythonization of sampling result & logprobs + tensor + * Encapsulate arguments required for deferred Pythonization + in the + [`SamplerOutput`][vllm.model_executor.layers.sampler.SamplerOutput] + structure + + Args: + logits: (num_tokens, vocab_size). + sampling_metadata: Metadata for sampling. + """ + assert logits is not None + _, vocab_size = logits.shape + + # Prepare sampling tensors with pinned memory to avoid blocking. + if not sampling_metadata.reuse_sampling_tensors: + self._init_sampling_tensors(logits, sampling_metadata) + elif self._do_penalties: + # In this case, the sampling tensors logic depends on + # "output_tokens" of a sequence. As a result, we cannot + # reuse sampling tensors, since "output_tokens" changes + # between decode runs. + self._init_sampling_tensors(logits, sampling_metadata) + + assert self._sampling_tensors is not None + sampling_tensors = self._sampling_tensors + do_penalties = self._do_penalties + do_top_p_top_k = self._do_top_p_top_k + do_min_p = self._do_min_p + + logits = _apply_min_tokens_penalty(logits, sampling_metadata) + + # Apply presence and frequency penalties. + if do_penalties: + logits = apply_penalties(logits, sampling_tensors.prompt_tokens, + sampling_tensors.output_tokens, + sampling_tensors.presence_penalties, + sampling_tensors.frequency_penalties, + sampling_tensors.repetition_penalties) + + # Use float32 to apply temperature scaling. + # Use in-place division to avoid creating a new tensor. + logits = logits.to(torch.float) + logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1)) + + if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None: + logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, + sampling_tensors.top_ks) + + if do_min_p: + logits = _apply_min_p(logits, sampling_tensors.min_ps) + + # We use float32 for probabilities and log probabilities. + # Compute the probabilities. + probs = torch.softmax(logits, dim=-1, dtype=torch.float) + # Compute the log probabilities. + logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) + + # Sample the next tokens. + maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample( + probs, + logprobs, + sampling_metadata, + sampling_tensors, + include_gpu_probs_tensor=self.include_gpu_probs_tensor, + modify_greedy_probs=self._should_modify_greedy_probs_inplace, + ) + + if self.include_gpu_probs_tensor: + # Since we will defer sampler result Pythonization, + # preserve GPU-side tensors in support of later + # deferred pythonization of logprobs + assert maybe_sampled_tokens_tensor is not None + on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) + else: + # Since Pythonization has already happened, don't preserve + # GPU-side tensors. + on_device_tensors = None + + # Get the logprobs query results. + prompt_logprobs = None + sample_logprobs = None + if not sampling_metadata.skip_sampler_cpu_output: + # Pythonize logprobs now (GPU -> CPU); do not defer. + assert not isinstance(maybe_deferred_sample_results, + SampleResultArgsType) + prompt_logprobs, sample_logprobs = get_logprobs( + logprobs, sampling_metadata, maybe_deferred_sample_results) + + return _build_sampler_output( + maybe_deferred_sample_results, + sampling_metadata, + prompt_logprobs, + sample_logprobs, + on_device_tensors=on_device_tensors, + skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output, + logits=logits if self.tree_decoding else None) + + @property + def _should_modify_greedy_probs_inplace(self) -> bool: + """Whether or not the sampler should modify the probability distribution + of greedily-sampled tokens such that multinomial sampling would sample + the greedily-sampled token. + + In other words, if True then we set the probability of the greedily- + sampled token to 1. + + This is used by speculative decoding, which requires that the sampling + method be encoded into the probability distribution. + """ + return self.should_modify_greedy_probs_inplace + + +def _apply_min_tokens_penalty( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens + have not been generated yet + """ + # list of indices in logits that will be set to -inf + logits_to_penalize: list[tuple[int, int]] = [] + logits_applied = 0 + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + + sample_indices = seq_group.sample_indices + logits_applied += len(sample_indices) + len( + seq_group.prompt_logprob_indices) + if not seq_group.do_sample: + continue + + start_idx = sample_indices[0] + min_tokens = sampling_params.min_tokens + token_ids_to_penalize = sampling_params.all_stop_token_ids + if min_tokens > 0 and token_ids_to_penalize: + seqs_to_penalize: list[int] = [] + for j, seq_id in enumerate(seq_ids): + seq_data = seq_group.seq_data[seq_id] + if len(seq_data.output_token_ids_array) < min_tokens: + seqs_to_penalize.append(j) + + if seqs_to_penalize: + # convert to the index into logits + seqs_to_penalize = [start_idx + j for j in seqs_to_penalize] + # itertools.product pairs each seq index with every token id + logits_to_penalize.extend( + itertools.product(seqs_to_penalize, token_ids_to_penalize)) + + if logits_to_penalize: + # use zip and * to group indices along each dimension + # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) ) + logits[tuple(zip(*logits_to_penalize))] = -float("inf") + + # verifies that no rows in logits were missed unexpectedly + assert logits_applied == logits.shape[0] + return logits + + +def _apply_top_k_top_p( + logits: torch.Tensor, + p: torch.Tensor, + k: torch.Tensor, +) -> torch.Tensor: + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + + # Apply top-k. + top_k_mask = logits_sort.size(1) - k.to(torch.long) + # Get all the top_k values. + top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) + top_k_mask = logits_sort < top_k_mask + logits_sort.masked_fill_(top_k_mask, -float("inf")) + + # Apply top-p. + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) + # at least one + top_p_mask[:, -1] = False + logits_sort.masked_fill_(top_p_mask, -float("inf")) + + # Re-sort the probabilities. + logits = torch.empty_like(logits_sort).scatter_(dim=-1, + index=logits_idx, + src=logits_sort) + return logits + + +def _apply_min_p( + logits: torch.Tensor, + min_p: torch.Tensor, +) -> torch.Tensor: + """ + Adapted from + https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17 + """ + probs = torch.softmax(logits, dim=-1) + top_probs, _ = probs.max(dim=-1, keepdim=True) + scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs + tokens_to_remove = probs < scaled_min_p + logits = logits.masked_fill_(tokens_to_remove, -float("inf")) + + return logits + + +def _greedy_sample( + selected_seq_groups: list[SequenceGroupToSample], + samples: torch.Tensor, +) -> SampleResultType: + """Run greedy sampling on a given samples. + + Args: + selected_seq_groups: A list of sequence groups batched. + samples: (num_selected_samples,) A tensor of samples. The length of + samples could be smaller than selected_seq_groups if + seq_group.do_sample is False. + Returns: + Tuple of (next_token_ids, parent_ids). The length of returned list is + same as the length of selected_seq_groups. If the corresponding + seq_group has do_sample=False, tuple contains ([], []) + """ + samples_lst = samples.tolist() + sample_idx = 0 + results: SampleResultType = [] + for seq_group in selected_seq_groups: + if not seq_group.do_sample: + results.append(([], [])) + continue + + seq_ids = seq_group.seq_ids + num_parent_seqs = len(seq_ids) + assert num_parent_seqs == 1, ( + "Greedy sampling should have only one seq.") + parent_ids = list(range(num_parent_seqs)) + next_token_ids = [samples_lst[sample_idx]] + results.append((next_token_ids, parent_ids)) + sample_idx += num_parent_seqs + return results + + +def _random_sample( + selected_seq_groups: list[SequenceGroupToSample], + random_samples: torch.Tensor, +) -> SampleResultType: + """Run random sampling on a given samples. + + Args: + selected_seq_groups: A list of sequence groups batched. + random_samples: (num_selected_samples,) A tensor of samples. The + length of samples could be smaller than selected_seq_groups if + seq_group.do_sample is False. + Returns: + Tuple of (next_token_ids, parent_ids). The length of returned list is + same as the length of selected_seq_groups. If the corresponding + seq_group has do_sample=False, tuple contains ([], []) + """ + # Find the maximum n value of the prompt phase requests. + random_samples = random_samples.cpu() + sample_idx = 0 + results: SampleResultType = [] + for seq_group in selected_seq_groups: + if not seq_group.do_sample: + results.append(([], [])) + continue + + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + is_prompt = seq_group.is_prompt + num_parent_seqs = len(seq_ids) + if is_prompt: + # Prompt phase. + parent_ids = [0] * sampling_params.n + next_token_ids = random_samples[ + sample_idx, :sampling_params.n].tolist() + else: + # Generation phase. + parent_ids = list(range(num_parent_seqs)) + next_token_ids = random_samples[sample_idx:sample_idx + + num_parent_seqs, 0].tolist() + results.append((next_token_ids, parent_ids)) + sample_idx += num_parent_seqs + return results + + +# torch.multinomial forces a GPU<->CPU sync. +# Therefore, we use an optimized implementation instead. +# Note that we always sample with replacement. +# probs will be modified in place, but this is fine, as we pass +# in a copy already. +def _multinomial( + probs: torch.Tensor, + num_samples: int, + seq_groups: Optional[list[SequenceGroupToSample]] = None, +) -> torch.Tensor: + if num_samples > 1: + probs = probs.repeat_interleave(num_samples, dim=0) + q = torch.empty_like(probs) + if seq_groups is None: + q.exponential_() + else: + sample_idx = 0 + for seq_group in seq_groups: + seq_ids = seq_group.seq_ids + stride = len(seq_ids) * num_samples + assert seq_group.generator is not None + q[sample_idx:sample_idx + + stride].exponential_(generator=seq_group.generator) + sample_idx += stride + return probs.div_(q).argmax(dim=1).view(-1, num_samples) + + +def _top_k_top_p_multinomial_with_flashinfer( + probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor, + num_samples: int, seq_groups: Optional[list[SequenceGroupToSample]]): + if num_samples > 1: + probs = probs.repeat_interleave(num_samples, dim=0) + top_ks = top_ks.repeat_interleave(num_samples) + top_ps = top_ps.repeat_interleave(num_samples) + batch_next_token_ids = flashinfer_top_k_top_p_sampling( + probs, + top_ks, + top_ps, + ) + return batch_next_token_ids.view(-1, num_samples) + + +def get_pythonized_sample_results( + sample_result_args: SampleResultArgsType) -> SampleResultType: + '''This function consumes GPU-side sampler results and computes + Pythonized CPU-side sampler results (GPU -> CPU sync.) + + Single-step scheduling: this function is invoked at sampling-time + for immediate Pythonization. + + Multi-step scheduling: Pythonization is deferred until after multiple + GPU-side steps have been completed. + + Args: + sample_result_args: GPU-side inputs to the Pythonization process + + Returns: + Pythonized sampler results + ''' + + ( + sample_metadata, + sampling_metadata, + greedy_samples, + multinomial_samples, + sample_results_dict, + ) = ( + sample_result_args.sample_metadata, + sample_result_args.sampling_metadata, + sample_result_args.greedy_samples, + sample_result_args.multinomial_samples, + sample_result_args.sample_results_dict, + ) + + for sampling_type in SamplingType: + if sampling_type not in sample_metadata: + continue + (seq_group_id, seq_groups) = sample_metadata[sampling_type] + if sampling_type == SamplingType.GREEDY: + sample_results = _greedy_sample(seq_groups, greedy_samples) + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): + sample_results = _random_sample(seq_groups, + multinomial_samples[sampling_type]) + sample_results_dict.update(zip(seq_group_id, sample_results)) + + return [ + sample_results_dict.get(i, ([], [])) + for i in range(len(sampling_metadata.seq_groups)) + ] + + +def _sample_with_torch( + probs: torch.Tensor, + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sampling_tensors: SamplingTensors, + include_gpu_probs_tensor: bool, + modify_greedy_probs: bool, +) -> SampleReturnType: + '''Torch-oriented _sample() implementation. + + Single-step scheduling: + * Perform GPU-side sampling computation + * Immediately Pythonize sampling result + + Multi-step scheduling: + * Perform GPU-side sampling computation + * Defer Pythonization & preserve GPU-side + tensors required for Pythonization + ''' + + categorized_seq_group_ids: dict[SamplingType, list[int]] = { + t: [] + for t in SamplingType + } + categorized_sample_indices = sampling_metadata.categorized_sample_indices + for i, seq_group in enumerate(sampling_metadata.seq_groups): + sampling_params = seq_group.sampling_params + sampling_type = sampling_params.sampling_type + categorized_seq_group_ids[sampling_type].append(i) + + sample_results_dict: SampleResultsDictType = {} + sample_metadata: SampleMetadataType = {} + multinomial_samples: MultinomialSamplesType = {} + greedy_samples: Optional[torch.Tensor] = None + + # Create output tensor for sampled token ids. + if include_gpu_probs_tensor: + sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1), + VLLM_INVALID_TOKEN_ID, + dtype=torch.long, + device=logprobs.device) + else: + sampled_token_ids_tensor = None + + # Counterintiutively, having two loops here is actually faster. + # The first loop can run without waiting on GPU<->CPU sync. + for sampling_type in SamplingType: + sample_indices = categorized_sample_indices[sampling_type] + num_tokens = len(sample_indices) + if num_tokens == 0: + continue + + seq_group_id = categorized_seq_group_ids[sampling_type] + seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id] + sample_metadata[sampling_type] = (seq_group_id, seq_groups) + long_sample_indices = sample_indices.long() + if sampling_type == SamplingType.GREEDY: + greedy_samples = torch.argmax(logprobs[long_sample_indices], + dim=-1) + + if sampled_token_ids_tensor is not None: + # Store sampled tokens in output tensor. + sampled_token_ids_tensor[ + long_sample_indices] = greedy_samples.unsqueeze(-1) + + if modify_greedy_probs: + # If required, modify the probabilities such that sampling from + # the modified distribution would always sample the argmax + # token id. + _modify_greedy_probs_inplace(logprobs, probs, + long_sample_indices, + greedy_samples) + + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): + max_n_in_batch = 1 + for seq_group in seq_groups: + if seq_group.is_prompt: + sampling_params = seq_group.sampling_params + max_n_in_batch = max(max_n_in_batch, sampling_params.n) + seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else + seq_groups) + + if flashinfer_top_k_top_p_sampling is not None: + logger.warning("FlashInfer 0.2.3+ does not support " + "per-request generators. Falling back to " + "PyTorch-native implementation.") + + multinomial_samples[sampling_type] = _multinomial( + probs[long_sample_indices], + max_n_in_batch, + seq_groups=seq_groups_arg) + + if sampled_token_ids_tensor is not None: + # Store sampled tokens in output tensor. + sampled_token_ids_tensor[long_sample_indices] = \ + multinomial_samples[sampling_type].to(torch.long) + + else: + raise ValueError(f"Unsupported sampling type: {sampling_type}") + + # Encapsulate arguments for computing Pythonized sampler + # results, whether deferred or otherwise. + maybe_deferred_args = SampleResultArgsType( + sampling_metadata=sampling_metadata, + sample_metadata=sample_metadata, + multinomial_samples=multinomial_samples, + greedy_samples=greedy_samples, + sample_results_dict=sample_results_dict) + + if not sampling_metadata.skip_sampler_cpu_output: + # GPU<->CPU sync happens here. + # This also converts the sampler output to a Python object. + # Return Pythonized sampler result & sampled token ids + return get_pythonized_sample_results( + maybe_deferred_args), sampled_token_ids_tensor + else: + # Defer sampler result Pythonization; return deferred + # Pythonization args & sampled token ids + return ( + maybe_deferred_args, + sampled_token_ids_tensor, + ) + + +def _sample( + probs: torch.Tensor, + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sampling_tensors: SamplingTensors, + include_gpu_probs_tensor: bool, + modify_greedy_probs: bool, +) -> SampleReturnType: + """ + Args: + probs: (num_query_tokens_in_batch, num_vocab) + logprobs: (num_query_tokens_in_batch, num_vocab) + sampling_metadata: The metadata for a batch for sampling. + sampling_tensors: Tensors that include sampling related metadata. + + Returns: + (next_token_ids, parent_seq_ids) for each seq group in a batch. + If sampling is skipped, it returns ([], []) + sampled_token_ids_tensor: A tensor of sampled token ids. + """ + return _sample_with_torch( + probs, + logprobs, + sampling_metadata, + sampling_tensors, + include_gpu_probs_tensor=include_gpu_probs_tensor, + modify_greedy_probs=modify_greedy_probs, + ) + + +def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + """ + This function calculates the ranks of the chosen tokens in a logprob tensor. + + Args: + x (torch.Tensor): 2D logprob tensor of shape (N, M) + where N is the no. of tokens and M is the vocab dim. + indices (torch.Tensor): List of chosen token indices. + + Returns: + torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens. + Each element in the returned tensor represents the rank + of the chosen token in the input logprob tensor. + """ + vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype), + indices] + result = (x > vals[:, None]) + del vals + return result.sum(1).add_(1) + + +def get_logprobs( + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sample_results: SampleResultType, +) -> tuple[list[Optional[PromptLogprobs]], list[SampleLogprobs]]: + """Return sample logprobs and prompt logprobs. + + The logic consists of 3 parts. + - Select indices to compute logprob from, ranks of token ids, and + the top k token ids from logprobs. + - Compute prompt logprobs if required. + - Compute sample logprobs if required. + + Args: + logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's + logprob per vocab. Sequence groups' query tokens are batched in a + single flattened tensor. For example, assuming there are N + seq groups, it is sorted by prefill tokens for seq_group_1 (if + prompt logprob is enabled), decode tokens for seq_group_1 (if + sampling is required), prefill tokens for seq_group_2, ... + sampling_metadata: The sampling metadata. + sample_results: (num_seq_groups) The tuple of (next_token_ids, + parent_ids) for each sequence group. When beam search is enabled, + sample_results can contain different number of seq_ids from + sampling_metadata.seq_groups. It is because beam search creates + 2 * BEAM_WIDTH number of samples (whereas there are only up to + BEAM_WIDTH number of seq_ids). + + Returns: + A tuple of prompt and sample logprobs per sequence group in a batch. + """ + # The index of query token to calculate logprobs. It includes both + # prompt and sample logprob indices. + query_indices: list[int] = [] + # The next token ids to get the logprob value from. + next_token_ids: list[int] = [] + # The largest requested number of logprobs. We find logprobs as many as the + # largest num logprobs in this API. If every logprobs is None, it will be + # set to -1. + largest_num_logprobs = -1 + + # Select indices to compute logprob from, ranks of token ids, and the top + # k token ids from logprobs. + for (seq_group, sample_result) in zip(sampling_metadata.seq_groups, + sample_results): + sampling_params = seq_group.sampling_params + + # Update indices and tokens for prompt logprobs. + if (seq_group.is_prompt + and sampling_params.prompt_logprobs is not None): + largest_num_logprobs = max(largest_num_logprobs, + sampling_params.prompt_logprobs) + next_prompt_tokens = _get_next_prompt_tokens(seq_group) + query_indices.extend(seq_group.prompt_logprob_indices) + next_token_ids.extend(next_prompt_tokens) + + # Update indices and next tokenes for sample logprob. + if seq_group.do_sample: + token_ids, parent_seq_ids = sample_result + # NOTE: We cannot directly use sample_indices because + # sample_indices only contain parent seq_ids of a previous step. + # The current step may have different number of seq_ids, and + # we can obtain it from `sample_result[1]`. + query_idx = seq_group.sample_indices[0] + query_indices.extend( + [query_idx + parent_id for parent_id in parent_seq_ids]) + next_token_ids.extend(token_ids) + + if sampling_params.logprobs is not None: + largest_num_logprobs = max(largest_num_logprobs, + sampling_params.logprobs) + + assert len(next_token_ids) == len(query_indices) + + if len(query_indices) == 0: + empty_sampled_logprob: SampleLogprobs = [] + empty_prompt_logprob: Optional[PromptLogprobs] = None + num_seq_groups = len(sampling_metadata.seq_groups) + return [empty_prompt_logprob + ] * num_seq_groups, [empty_sampled_logprob] * num_seq_groups + + selected_logprobs, ranks = None, None + top_logprobs, top_token_ids = None, None + + # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can + # skip the whole logprob calculation. + if largest_num_logprobs >= 0: + query_indices_gpu = torch.tensor(query_indices, device=logprobs.device) + next_token_ids_gpu = torch.tensor(next_token_ids, + device=logprobs.device) + + # (num_selected_query_tokens, num_logprobs). Note that query_indices can + # contain duplicates if beam search is enabled. + selected_logprobs = logprobs[[ + query_indices_gpu, + next_token_ids_gpu, + ]] + ranks = _get_ranks( + logprobs[query_indices_gpu], + next_token_ids_gpu, + ) + assert selected_logprobs.shape[0] == ranks.shape[0] + + # We need to compute top k only if there exists logprobs > 0. + if largest_num_logprobs > 0: + # Logprobs of topk tokens for a batch of sequence groups. + # (num_query_tokens_across_batch). + top_logprobs, top_token_ids = torch.topk(logprobs, + largest_num_logprobs, + dim=-1) + top_logprobs = top_logprobs.to('cpu') + top_token_ids = top_token_ids.to('cpu') + + selected_logprobs = selected_logprobs.to('cpu') + ranks = ranks.to('cpu') + + # Find prompt/sample logprobs. + prompt_logprobs_per_seq_group: list[Optional[PromptLogprobs]] = [] + sample_logprobs_per_seq_group: list[SampleLogprobs] = [] + top_logprob_idx = 0 + selected_logprobs_idx = 0 + + for seq_group, sample_result in zip(sampling_metadata.seq_groups, + sample_results): + (prompt_logprobs, top_logprob_idx, + selected_logprobs_idx) = _get_prompt_logprob_if_needed( + seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs, + selected_logprobs_idx, top_logprob_idx) + prompt_logprobs_per_seq_group.append(prompt_logprobs) + + (sampled_logprobs, top_logprob_idx, + selected_logprobs_idx) = _get_sampled_logprob_if_needed( + seq_group, sample_result, selected_logprobs, ranks, top_token_ids, + top_logprobs, selected_logprobs_idx, top_logprob_idx) + sample_logprobs_per_seq_group.append(sampled_logprobs) + + return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group + + +def _get_prompt_logprob_if_needed( + seq_group: SequenceGroupToSample, + selected_logprobs: torch.Tensor, + ranks: torch.Tensor, + top_token_ids: torch.Tensor, + top_logprobs: torch.Tensor, + selected_logprobs_idx: int, + top_logprob_idx: int, +): + """Compute the prompt logprob from a sequence group if needed.""" + sampling_params = seq_group.sampling_params + is_prompt = seq_group.is_prompt + + # Find prompt logprobs + prompt_logprobs: Optional[PromptLogprobs] = None + if is_prompt and sampling_params.prompt_logprobs is not None: + prompt_logprobs = [] + num_logprobs = sampling_params.prompt_logprobs + next_prompt_tokens = _get_next_prompt_tokens(seq_group) + # Pre-select indexes and create a list. It is faster than calling .item + # repetitively. + selected_logprob_items = selected_logprobs[ + selected_logprobs_idx:selected_logprobs_idx + + len(next_prompt_tokens)].tolist() + rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + + len(next_prompt_tokens)].tolist() + + for idx, token_id in enumerate(next_prompt_tokens): + # Calculate the prompt logprob of the real prompt tokens. + # {token_id: (logprob, rank_from_vocab)} + prompt_logprobs_dict: dict[int, tuple[float, int]] = { + token_id: (selected_logprob_items[idx], rank_items[idx]) + } + + # Add top K prompt logprobs along with its rank. + if num_logprobs > 0: + top_ids = top_token_ids[ + top_logprob_idx, :num_logprobs].tolist() + top_probs = top_logprobs[ + top_logprob_idx, :num_logprobs].tolist() + # Top K is already sorted by rank, so we can use 1 ~ + # num_logprobs + 1 for rank. + top_ranks = range(1, num_logprobs + 1) + prompt_logprobs_dict.update({ + top_id: (top_prob, rank) + for top_id, top_prob, rank in zip(top_ids, top_probs, + top_ranks) + }) + prompt_logprobs.append({ + token_id: Logprob(*logprob_and_rank) + for token_id, logprob_and_rank in prompt_logprobs_dict.items() + }) + # + 1 to go to the next prompt token. + top_logprob_idx += 1 + + # + len(next_prompt_tokens) to go to the next prompt. + selected_logprobs_idx += len(next_prompt_tokens) + return prompt_logprobs, top_logprob_idx, selected_logprobs_idx + + +def _get_sampled_logprob_if_needed( + seq_group: SequenceGroupToSample, + sample_result: tuple[list[int], list[int]], + selected_logprobs: torch.Tensor, + ranks: torch.Tensor, + top_token_ids: torch.Tensor, + top_logprobs: torch.Tensor, + selected_logprobs_idx: int, + top_logprob_idx: int, +): + """Compute the sample logprob if needed.""" + seq_ids = seq_group.seq_ids + num_logprobs = seq_group.sampling_params.logprobs + sampled_logprobs: SampleLogprobs = [] + next_token_ids, parent_seq_ids = sample_result + + if seq_group.do_sample: + assert len(next_token_ids) > 0 + if num_logprobs is None: + for next_token_id in next_token_ids: + # Use a dummy logprob + sampled_logprobs.append({next_token_id: Logprob(inf)}) + else: + # Pre-select items from tensor. tolist() is faster than repetitive + # `.item()` calls. + selected_logprob_items = selected_logprobs[ + selected_logprobs_idx:selected_logprobs_idx + + len(next_token_ids)].tolist() + rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + + len(next_token_ids)].tolist() + for idx, (next_token_id, parent_id) in enumerate( + zip(next_token_ids, parent_seq_ids)): + # Get the logprob of a sampled token. + sampled_logprobs_dict = { + next_token_id: + (selected_logprob_items[idx], rank_items[idx]) + } + if num_logprobs is not None and num_logprobs > 0: + # Get top K logprobs. + top_ids = top_token_ids[top_logprob_idx + + parent_id, :num_logprobs].tolist() + top_probs = top_logprobs[ + top_logprob_idx + parent_id, :num_logprobs].tolist() + # Top K is already sorted by rank, so we can use 1 ~ + # num_logprobs + 1 for rank. + top_ranks = range(1, num_logprobs + 1) + sampled_logprobs_dict.update({ + top_id: (top_prob, rank) + for top_id, top_prob, rank in zip( + top_ids, top_probs, top_ranks) + }) + + sampled_logprobs.append({ + token_id: Logprob(*logprob_and_rank) + for token_id, logprob_and_rank in + sampled_logprobs_dict.items() + }) + + # NOTE: This part of code is not intuitive. `selected_logprobs` include + # logprobs for the current step, which has len(next_token_ids) tokens + # per sequence group. `logprobs` includes logprobs from the previous + # steps, which has len(seq_ids) tokens per sequence group. + + # Iterate to the next sequence group in a batch. + selected_logprobs_idx += len(next_token_ids) + # Iterate to the next sequence group in a batch. + top_logprob_idx += len(seq_ids) + return sampled_logprobs, top_logprob_idx, selected_logprobs_idx + + +def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, + sample_indices: torch.Tensor, + greedy_samples: torch.Tensor) -> None: + """Modify the probability distributions of the greedily-sampled tokens such + that each sampled token has a "probability" of 1.0. This is required by + speculative decoding, which depends on the sampling method being encoded + within the probability distribution for correctness. + + # Why do we only need to do this for greedy sampling? + + vLLM's sampler performs the following steps for greedy or multinomial + (random) sampling: + 1. Get logits from model. + 2. Modify logits according to per-sequence sampling parameters. + - Multiply by temperature, top-k and top-p masking, penalize tokens + according to their frequency, etc. + 3. Sample a token. + - Random sampling simply samples from the modified probability + distribution. + - Greedy sampling performs `argmax` to obtain the token with the + highest likelihood. + + Ignoring greedy sampling for a moment, we find that the computed probability + distribution has the following property: we can sample from it independently + and find that the token sampled by the Sampler has a frequency corresponding + to how often we see it in our sampling. In other words, for tokens sampled + with vLLM's random SamplingType, the computed probability distribution + encodes the sampling methodology completely. + + Greedy sampling does not normally have this property. vLLM modifies logits + according to sampling params, then performs `argmax`, then returns the + sampled token and the computed probability distribution. If we sample from + the distribution, we'll find the likelihood of the greedily-sampled token + is not always 1.0. + + Since lossless speculative decoding requires that the sampling methodology + be encoded within the probability distribution, we are motivated to modify + the probability distribution such that the sampled token has probability 1 + when speculative decoding is used. + + NOTE: Alternatively, we could use an extremely low temperature to achieve + greedy sampling using multinomial computation and unite the codepaths. This + has implications on the overall design of the sampler, e.g. how to record + accurate logprobs for the user, so this improvement is deferred to later. + """ + # NOTE: logprobs are not modified so they can be returned to the user. + probs[sample_indices, :] = 0 + probs[sample_indices, greedy_samples] = 1.0 + + +def _build_sampler_output( + maybe_deferred_sample_results: MaybeDeferredSampleResultType, + sampling_metadata: SamplingMetadata, + prompt_logprobs: Optional[list[Optional[PromptLogprobs]]], + sample_logprobs: Optional[list[SampleLogprobs]], + on_device_tensors: Optional[tuple[torch.Tensor, torch.Tensor, + torch.Tensor]], + skip_sampler_cpu_output: bool = False, + logits: Optional[torch.Tensor] = None +) -> SamplerOutput: + """Construct Python objects with the output of sampling. + + Args: + on_device_tensors: Tuple containing on-device tensors with the + probabilities used in sampling and the sampled token ids. This + allows post-processing without copies to CPU/serialization, e.g. in + speculative decoding rejection sampling. + """ + sampler_output: list[CompletionSequenceGroupOutput] = [] + + if skip_sampler_cpu_output: + assert isinstance(maybe_deferred_sample_results, SampleResultArgsType) + deferred_sample_results_args = maybe_deferred_sample_results + else: + assert prompt_logprobs is not None + assert sample_logprobs is not None + assert not isinstance(maybe_deferred_sample_results, + SampleResultArgsType) + assert len(sampling_metadata.seq_groups) \ + == len(maybe_deferred_sample_results) \ + == len(prompt_logprobs) \ + == len(sample_logprobs) + deferred_sample_results_args = None + + for (seq_group, sample_result, group_prompt_logprobs, + group_sample_logprobs) in zip(sampling_metadata.seq_groups, + maybe_deferred_sample_results, + prompt_logprobs, sample_logprobs): + seq_ids = seq_group.seq_ids + next_token_ids, parent_ids = sample_result + seq_outputs: list[SequenceOutput] = [] + for parent_id, next_token_id, logprobs in zip( + parent_ids, next_token_ids, group_sample_logprobs): + seq_outputs.append( + SequenceOutput(seq_ids[parent_id], next_token_id, + logprobs)) + sampler_output.append( + CompletionSequenceGroupOutput(seq_outputs, + group_prompt_logprobs)) + + # If not specified, store None values in SamplerOutput. + if on_device_tensors is not None: + (sampled_token_probs, logprobs_tensor, + sampled_token_ids) = on_device_tensors + else: + sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, + None) + + return SamplerOutput( + outputs=sampler_output, + sampled_token_probs=sampled_token_probs, + sampled_token_ids=sampled_token_ids, + logprobs=logprobs_tensor, + deferred_sample_results_args=deferred_sample_results_args, + logits=logits) + + +def _get_next_prompt_tokens( + seq_group: SequenceGroupToSample) -> tuple[int, ...]: + """Get a list of next prompt tokens to compute logprob from a + given sequence group. + + It is used to compute prompt logprob. Imagine you have logprob for each + query token. Query token needs to know the next prompt token id to compute + prompt logprob. This is a helper to obtain next prompt token ids. + + This API has to be used only when the caller knows seq_group is in prefill + stage. + + Returns: + A list of next prompt tokens to compute logprob. + """ + assert seq_group.is_prompt, ( + "Caller should ensure the sequence group is in a prefill stage.") + seq_ids = seq_group.seq_ids + query_len = seq_group.query_len + assert query_len is not None + # prompt has only 1 seq id. + assert len(seq_ids) == 1 + seq_data = seq_group.seq_data[seq_ids[0]] + computed_len = seq_data.get_num_computed_tokens() + prompt_tokens = seq_data.prompt_token_ids + # +1 because we are looking for a next prompt token. + next_token_index_start = computed_len + 1 + next_token_index_end = min(computed_len + query_len + 1, + len(prompt_tokens)) + next_prompt_tokens = prompt_tokens[ + next_token_index_start:next_token_index_end] + return next_prompt_tokens diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py new file mode 100644 index 0000000..0a36fe9 --- /dev/null +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -0,0 +1,259 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import abstractmethod +from typing import Optional, Union + +import torch +import torch.jit +import torch.nn as nn + +from vllm.platforms import current_platform + + +class SpecDecodeBaseSampler(nn.Module): + """Base class for samplers used for Speculative Decoding verification + step. + """ + + def __init__(self, strict_mode: bool = False): + """Base class constructor. + Args: + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. + """ + super().__init__() + self._strict_mode = strict_mode + + # NOTE: A "bonus token" is accepted iff all proposal tokens are + # accepted. There is always only one possible bonus token. We store this + # value in a variable for readability. + self._num_bonus_tokens = 1 + + self.num_accepted_tokens: Optional[torch.Tensor] = None + self.num_emitted_tokens: Optional[torch.Tensor] = None + self.num_draft_tokens: int = 0 + + def init_gpu_tensors(self, device: Union[int, str]) -> None: + assert self.num_accepted_tokens is None + if isinstance(device, int): + device = f"{current_platform.device_type}:{device}" + elif not isinstance(device, str): + raise ValueError(f"Device must be int or str, get {type(device)}") + self.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + self.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + + def init_tensors(self, + device: Union[int, str], + device_type: Union[torch.device, str] = 'cuda') -> None: + assert self.num_accepted_tokens is None + if isinstance(device_type, torch.device): + device_type = device_type.type + if isinstance(device, int): + device = f"{device_type}:{device}" + self.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + self.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + + @property + def probs_dtype(self): + return torch.float32 + + @property + def token_id_dtype(self): + return torch.int64 + + def _create_output( + self, + accepted: torch.Tensor, # [batch_size, k] + substitute_token_ids: torch.Tensor, # [batch_size, k] + draft_token_ids: torch.Tensor, # [batch_size, k] + bonus_token_ids: torch.Tensor, # [batch_size] + ) -> torch.Tensor: + """Format output. Returns a matrix of token ids. When + a token is rejected via sampling, all subsequent token ids are + set to -1 for the sequence. + + Args: + accepted: A boolean tensor indicating if the corresponding + draft token in draft_token_ids should be accepted or not. + substitute_token_ids: A tensor of token_ids that can be used + as substitutes for the draft token ids if the proposed token + is rejected. + draft_token_ids: A tensor of token ids speculated by the + draft model. + bonus_token_ids: Token ids to use as the bonus token if + all the draft tokens are accepted. + Returns: + A tensor containing the accepted token ids. The shape of the + tensor is [batch_size, k + num_bonus_tokens] + """ + batch_size, k = substitute_token_ids.shape + bonus_token_ids = bonus_token_ids.squeeze(-1) + # Determine the index of the first False value for each row. + limits = (accepted == 0).max(1).indices + limits[~(accepted == 0).any(1)] = k + + # Create masks using the indices. + indices = torch.arange(k, device=accepted.device).unsqueeze(0) + accepted_mask = indices < limits.unsqueeze(1) + after_false_mask = indices == limits.unsqueeze(1) + + # Create an extended output tensor + output_with_bonus_tokens = -torch.ones( + (batch_size, k + self._num_bonus_tokens), + dtype=self.token_id_dtype, + device=accepted.device) + output = output_with_bonus_tokens[:, :k] + + # Fill in the first k columns of the output tensor using masks and data + # tensors. + output[:, :k] = torch.where(accepted_mask, draft_token_ids, + -torch.ones_like(draft_token_ids)) + + # Fill the last column. + # We check output directly as accepted may have True values inconsistent + # with causal acceptance. + output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, + bonus_token_ids, -1) + + # Fill the recovered token ids. + output.mul_(~after_false_mask).add_( + substitute_token_ids.mul(after_false_mask)) + + self.num_accepted_tokens += accepted.sum() + self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() + self.num_draft_tokens += batch_size * k + + return output_with_bonus_tokens + + def _raise_if_incorrect_input( + self, + target_with_bonus_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + self._raise_if_incorrect_shape(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + self._raise_if_incorrect_dtype(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + self._raise_if_inconsistent_device(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + self._raise_if_out_of_bounds_vocab(target_with_bonus_probs.shape[-1], + draft_token_ids, bonus_token_ids) + + def _raise_if_incorrect_shape( + self, + target_with_bonus_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + (target_batch_size, num_target_probs, + target_vocab_size) = target_with_bonus_probs.shape + + # Does not count the extra token + num_target_probs -= 1 + + # validate the shape of draft token ids. + draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape + assert draft_token_ids_batch_size == target_batch_size + assert num_draft_token_ids == num_target_probs + + # validate the shape of bonus token ids + bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape + assert bonus_batch_size == target_batch_size + assert num_bonus_tokens == self._num_bonus_tokens + + # validate the shape of draft probs if it is set + if draft_probs is not None: + (draft_batch_size, num_draft_probs, + draft_vocab_size) = draft_probs.shape + assert draft_batch_size == target_batch_size + assert num_draft_probs == num_target_probs + assert (draft_vocab_size == target_vocab_size + ), f"{draft_vocab_size=} {target_vocab_size=}" + + def _raise_if_incorrect_dtype( + self, + target_with_bonus_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + assert target_with_bonus_probs.dtype == self.probs_dtype + assert draft_token_ids.dtype == self.token_id_dtype + assert bonus_token_ids.dtype == self.token_id_dtype + if draft_probs is not None: + assert draft_probs.dtype == self.probs_dtype + + def _raise_if_inconsistent_device( + self, + target_with_bonus_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + devices = [ + t.device for t in [ + target_with_bonus_probs, bonus_token_ids, draft_probs, + draft_token_ids + ] if t is not None + ] + assert all([devices[0] == device for device in devices]) + + def _raise_if_out_of_bounds_vocab( + self, + vocab_size: int, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + ) -> None: + assert torch.all(bonus_token_ids < vocab_size) + assert torch.all(bonus_token_ids >= 0) + assert torch.all(draft_token_ids < vocab_size) + assert torch.all(draft_token_ids >= 0) + + +class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler): + """Base class for samplers used for Speculative Decoding verification + step which are deterministic. + """ + + @abstractmethod + def forward( + self, + target_with_bonus_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + +class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler): + """Base class for samplers used for Speculative Decoding verification + step which are stochastic + """ + + @abstractmethod + def forward( + self, + target_with_bonus_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + seeded_seqs: Optional[dict[int, torch.Generator]] = None, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py new file mode 100644 index 0000000..932c4a4 --- /dev/null +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -0,0 +1,299 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from typing import Optional, List +import torch +import torch.jit +import torch.nn.functional as F + +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeDeterministicBaseSampler) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): + """Apply typical acceptance sampling as described in section 3.3.1 in + "MEDUSA: Simple LLM Inference Acceleration Framework with + Multiple Decoding Heads" + https://arxiv.org/pdf/2401.10774 + """ + + def __init__( + self, + posterior_threshold: float, + posterior_alpha: float, + strict_mode: bool = False, + ): + """Create a Typical Acceptance Sampler. + + Args: + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. + posterior_threshold : A threshold value that sets a lower bound + on the posterior probability of a token in target model for it + to be accepted. + posterior_alpha : A scaling factor for the entropy-based + threshold in typical acceptance sampling. + """ + self._posterior_threshold = posterior_threshold + self._posterior_alpha = posterior_alpha + super().__init__(strict_mode=strict_mode) + + self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1') + + def forward( + self, + target_with_bonus_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + cart_candidates: Optional[torch.Tensor] = None, + best_candidates: Optional[List] = None, + accept_lengths: Optional[List] = None, + first_step_flags: Optional[List] = None, + ) -> torch.Tensor: + """Sample token ids using typical acceptance sampling. This accepts + or rejects tokens proposed by the draft model using the probability + of each token according to the draft and target models. + + In the worst case where all draft tokens are rejected, it is guaranteed + one token will be emitted. + + In the case where all draft tokens are accepted, the bonus token will be + accepted. + + Args: + target_probs: The probability distribution over token ids given + context according to the target model. + shape = [batch_size, num_speculative_tokens, vocab_size] + + bonus_token_ids: The "bonus" token ids that are accepted iff all + speculative tokens in a sequence are accepted. + shape = [batch_size, num_bonus_tokens] + + draft_probs: This parameter is unused by the acceptance sampler. + + draft_token_ids: The token ids that were sampled from the draft + probabilities. + shape = [batch_size, num_speculative_tokens] + + cart_candidates: tree-style cartesian candidates + best_candidates: pending to write best candidates index + accept_lengths: pending to write accept lengths + first_step_flags: whether this is the first decoding step + + Returns: + output_token_ids: The token ids sampled via rejection sampling, + or -1 if unable to sample a token because the previous token + was rejected. + shape = [batch_size, num_speculative_tokens + num_bonus_tokens] + """ + # Only perform shape/dtype/device checking in strict mode, as it adds + # overhead. + if self._strict_mode: + self._raise_if_incorrect_input(target_with_bonus_probs, + draft_token_ids, bonus_token_ids) + + if not self.tree_decoding: + target_probs = target_with_bonus_probs[:, :-1] + accepted = self._evaluate_accepted_tokens(target_probs, + draft_token_ids) + recovered_token_ids = self._get_recovered_token_ids(target_probs) + output_token_ids = self._create_output(accepted, recovered_token_ids, + draft_token_ids, + bonus_token_ids) + else: + assert cart_candidates is not None + target_probs = target_with_bonus_probs + output_token_ids = self._evaluate_accepted_tokens_tree_style(target_probs, + draft_token_ids, + cart_candidates, + best_candidates, + accept_lengths, + first_step_flags) + return output_token_ids + + def _evaluate_accepted_tokens_tree_style(self, target_probs, draft_token_ids, + cart_candidates, output_best_candidates, + accept_lengths, first_step_flags): + r""" + Evaluates and returns a mask of accepted tokens based on the + posterior probabilities. + + Parameters: + ---------- + target_probs : torch.Tensor + A tensor of shape (batch_size, k, vocab_size) representing + the probabilities of each token in the vocabulary for each + position in the proposed sequence. This is the distribution + generated by the target model. + draft_token_ids : torch.Tensor + A tensor of shape (batch_size, k) representing the proposed + token ids. + cart_candidates : torch.Tensor + A tensor of shape (batch_size, retrieve_size, tree_depth) + representing the cart candidates of tree proposals. + + A draft token_id x_{n+k} is accepted if it satisfies the + following condition + + .. math:: + p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) > + \min \left( \epsilon, \delta * \exp \left( + -H(p_{\text{original}}( + \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right) + + where :math:`p_{\text{original}}` corresponds to target_probs + and :math:`\epsilon` and :math:`\delta` correspond to hyperparameters + specified using self._posterior_threshold and self._posterior_alpha + + This method computes the posterior probabilities for the given + draft token ids based on the provided target probabilities. It + calculates the entropy of the posterior distribution and determines + a dynamic threshold for each token position using the provided + posterior_threshold and posterior_alpha values. The method then + returns a boolean mask indicating which tokens can be accepted. + + Returns: + ------- + torch.Tensor + A boolean tensor of shape (batch_size, k) where each element + indicates whether the corresponding draft token has been accepted + or rejected. True indicates acceptance and false indicates + rejection. + + """ + target_probs = target_probs[:, :, :-1] + device = target_probs.device + batch_size = cart_candidates.shape[0] + candidates_prob = torch.gather( + target_probs, dim=-1, index=cart_candidates[:, :, 1:].unsqueeze(-1) + ).squeeze(-1) # [batch_size, retrieve_size, max_depth] + posterior_entropy = -torch.sum( + target_probs * torch.log(target_probs + 1e-5), dim=-1 + ) # torch.sum(torch.log(*)) is faster than torch.prod [batch_size, retrieve_size, max_depth] + threshold = torch.minimum( + torch.ones_like(posterior_entropy) * self._posterior_threshold, + torch.exp(-posterior_entropy) * self._posterior_alpha, + ) + posterior_mask = candidates_prob > threshold # [batch_size, retrieve_size, max_depth] + candidates_accept_length = (torch.cumprod(posterior_mask, dim=2)).sum(dim=-1) # [batch_size, retrieve_size] + + # Choose the best candidate based on the evaluated posterior probabilities + accept_length, _ = candidates_accept_length.max(dim=-1) # [batch_size] + if torch.any(accept_length > 0): + valid_index = (candidates_accept_length == accept_length.unsqueeze(-1)).unsqueeze(-1) # [batch_size, retrieve_size, 1] + + candidates_prob = candidates_prob * valid_index # [batch_size, retrieve_size, max_depth] + valid_index = torch.arange(candidates_prob.shape[-1], device=device).unsqueeze(0).unsqueeze(0).repeat( + batch_size, candidates_prob.shape[1], 1) # [batch_size, retrieve_size, max_depth] + valid_index = (valid_index < accept_length.unsqueeze(1).unsqueeze(2).repeat(1, candidates_prob.shape[1], 1)) # [batch_size, retrieve_size, 1] + candidates_prob = candidates_prob*valid_index # [batch_size, retrieve_size, max_depth] + + # add 1e-3 to avoid zero value + likelihood = torch.sum(torch.log(candidates_prob + 1e-3), dim=-1) # [batch_size, retrieve_size] + + best_candidate = torch.argmax(likelihood, dim=-1) # [batch_size] + else: + # Choose the best candidate + best_candidate = torch.zeros((batch_size), dtype=torch.long, device=device) # [batch_size] + + k = draft_token_ids.shape[-1] + output_token_id_list = [] + + accept_length_list = accept_length.cpu().tolist() + #logger.info("accept_length:%s", accept_length_list) + for i in range(batch_size): + output_best_candidates.append(best_candidate[i]) + accept_lengths.append(accept_length_list[i]) + + if not first_step_flags[i]: + select_indices = cart_candidates[i, best_candidate[i], : accept_length[i] + 1] + select_indices = F.pad(select_indices, (0, k - 1 - accept_length[i]), 'constant', -1) + else: + select_indices = cart_candidates[i, best_candidate[i], 1 : accept_length[i] + 1] + select_indices = F.pad(select_indices, (0, k - accept_length[i]), 'constant', -1) + output_token_id_list.append(select_indices) + + return torch.stack(output_token_id_list, dim=0) + + + def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): + r""" + Evaluates and returns a mask of accepted tokens based on the + posterior probabilities. + + Args: + target_probs (torch.Tensor): A tensor of shape + (batch_size, k, vocab_size) representing the probabilities of + each token in the vocabulary for each position in the proposed + sequence. This is the distribution generated by the target + model. + draft_token_ids (torch.Tensor): A tensor of shape (batch_size, k) + representing the proposed token ids. + + A draft token_id x_{n+k} is accepted if it satisfies the + following condition + + $$ + p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) > + \min \left( \epsilon, \delta * \exp \left( + -H(p_{\text{original}}( + \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right) + $$ + + where $p_{\text{original}}$ corresponds to target_probs + and $\epsilon$ and $\delta$ correspond to hyperparameters + specified using self._posterior_threshold and self._posterior_alpha + + This method computes the posterior probabilities for the given + draft token ids based on the provided target probabilities. It + calculates the entropy of the posterior distribution and determines + a dynamic threshold for each token position using the provided + posterior_threshold and posterior_alpha values. The method then + returns a boolean mask indicating which tokens can be accepted. + + Returns: + torch.Tensor: A boolean tensor of shape (batch_size, k) where each + element indicates whether the corresponding draft token has + been accepted or rejected. True indicates acceptance and false + indicates rejection. + """ + device = target_probs.device + candidates_prob = torch.gather( + target_probs, dim=-1, + index=draft_token_ids.unsqueeze(-1)).squeeze(-1) + # A small constant added to prevent computing the logarithm of zero, + # which can lead to undefined values. + epsilon = 1e-5 + posterior_entropy = -torch.sum( + target_probs * torch.log(target_probs + epsilon), dim=-1) + threshold = torch.minimum( + torch.ones_like(posterior_entropy, device=device) * + self._posterior_threshold, + torch.exp(-posterior_entropy) * self._posterior_alpha, + ) + accepted_mask = candidates_prob > threshold + return accepted_mask + + def _get_recovered_token_ids(self, target_probs): + """ + The recovered token ids will fill the first unmatched token + by the target token. + + Args: + target_probs (torch.Tensor): A tensor of shape + (batch_size, k, vocab_size) containing the target probability + distribution. + + Returns: + torch.Tensor: A tensor of shape (batch_size, k) with the recovered + token ids which are selected from target probs. + """ + max_indices = torch.argmax(target_probs, dim=-1) + + return max_indices diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py new file mode 100644 index 0000000..939d7df --- /dev/null +++ b/vllm/model_executor/layers/utils.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Utility methods for model layers.""" +from typing import Callable, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm import envs +from vllm.platforms import current_platform + + +def get_token_bin_counts_and_mask( + tokens: torch.Tensor, + vocab_size: int, + num_seqs: int, +) -> tuple[torch.Tensor, torch.Tensor]: + # Compute the bin counts for the tokens. + # vocab_size + 1 for padding. + bin_counts = torch.zeros((num_seqs, vocab_size + 1), + dtype=torch.long, + device=tokens.device) + bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) + bin_counts = bin_counts[:, :vocab_size] + mask = bin_counts > 0 + + return bin_counts, mask + + +def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, + output_tokens_tensor: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor) -> torch.Tensor: + """ + Applies penalties in place to the logits tensor + logits : The input logits tensor of shape [num_seqs, vocab_size] + prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts + are padded to the maximum prompt length within the batch using + `vocab_size` as the padding value. The value `vocab_size` is used + for padding because it does not correspond to any valid token ID + in the vocabulary. + output_tokens_tensor: The output tokens tensor. + presence_penalties: The presence penalties of shape (num_seqs, ) + frequency_penalties: The frequency penalties of shape (num_seqs, ) + repetition_penalties: The repetition penalties of shape (num_seqs, ) + """ + num_seqs, vocab_size = logits.shape + _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor, + vocab_size, num_seqs) + output_bin_counts, output_mask = get_token_bin_counts_and_mask( + output_tokens_tensor, vocab_size, num_seqs) + + # Apply repetition penalties as a custom op + from vllm._custom_ops import apply_repetition_penalties + apply_repetition_penalties(logits, prompt_mask, output_mask, + repetition_penalties) + + # We follow the definition in OpenAI API. + # Refer to https://platform.openai.com/docs/api-reference/parameter-details + logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts + logits -= presence_penalties.unsqueeze(dim=1) * output_mask + return logits + + +def default_unquantized_gemm(layer: torch.nn.Module, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None): + return torch.nn.functional.linear(x, weight, bias) + + +def rocm_unquantized_gemm(layer: torch.nn.Module, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None): + from vllm.platforms.rocm import on_gfx9 + k = weight.shape[1] + use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \ + x.dtype in [torch.float16, torch.bfloat16] \ + and k % 8 == 0 and bias is None) + + if use_skinny is not True: + return torch.nn.functional.linear(x, weight, bias) + + x_view = x.view(-1, x.size(-1)) + n = x_view.shape[0] + m = weight.shape[0] + cu_count = current_platform.get_cu_count() + + if m > 8 and 0 < n <= 4: + out = ops.wvSplitK(weight, x_view, cu_count) + return out.view(*x.shape[:-1], weight.shape[0]) + elif m % 4 == 0 and n == 1 and k <= 8192: + out = ops.LLMM1(weight, x_view, 4) + return out.view(*x.shape[:-1], weight.shape[0]) + return torch.nn.functional.linear(x, weight, bias) + + +def cpu_unquantized_gemm(layer: torch.nn.Module, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None): + if getattr(layer, "use_cpu_sgl", False): + return torch.ops._C.weight_packed_linear(x, weight, bias, True) + else: + return torch.nn.functional.linear(x, weight, bias) + + +def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]: + if current_platform.is_rocm(): + # return rocm_unquantized_gemm + return torch.nn.functional.linear + elif current_platform.is_cpu(): + return cpu_unquantized_gemm + else: + return default_unquantized_gemm diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py new file mode 100644 index 0000000..4ba2a18 --- /dev/null +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -0,0 +1,545 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from dataclasses import dataclass + +import vllm.envs as envs +import os + +from typing import Optional, Sequence + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter, UninitializedParameter + +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) +from vllm.model_executor.layers.utils import dispatch_unquantized_gemm +from vllm.model_executor.parameter import BasevLLMParameter +from vllm.model_executor.utils import set_weight_attrs + +from vllm.platforms import current_platform +from vllm.utils import SUPPORT_TC + + +DEFAULT_VOCAB_PADDING_SIZE = 64 + + +class UnquantizedEmbeddingMethod(QuantizeMethodBase): + """Unquantized method for embeddings.""" + + def __init__(self): + self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """Create weights for embedding layer.""" + # if envs.VLLM_USE_NN: + # weight = Parameter(torch.empty(input_size_per_partition, + # sum(output_partition_sizes), + # dtype=params_dtype), + # requires_grad=False) + # else: + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_llama_nn and os.environ['LM_NN'] == '1': + if bias is not None: + if len(x.shape) == 2: + return torch.addmm(bias, x, layer.weight) + else: + return torch.matmul(x, layer.weight) + bias + else: + return torch.matmul(x, layer.weight) + else: + if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]: + return dispatch_unquantized_gemm()(x, layer.weight.t(), bias) + else: + return dispatch_unquantized_gemm()(x, layer.weight, bias) + + + def embedding(self, layer: torch.nn.Module, + input_: torch.Tensor) -> torch.Tensor: + return F.embedding(input_, layer.weight) + + +def pad_vocab_size(vocab_size: int, + pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: + """Pad the vocab size to the given value.""" + return ((vocab_size + pad_to - 1) // pad_to) * pad_to + + +def vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size: int, + rank: int, + offset: int = 0) -> Sequence[int]: + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f + offset, index_l + offset + + +def vocab_range_from_global_vocab_size(global_vocab_size: int, + rank: int, + world_size: int, + offset: int = 0) -> Sequence[int]: + per_partition_vocab_size = divide(global_vocab_size, world_size) + return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, + rank, + offset=offset) + + +@dataclass +class VocabParallelEmbeddingShardIndices: + """Indices for a shard of a vocab parallel embedding.""" + padded_org_vocab_start_index: int + padded_org_vocab_end_index: int + padded_added_vocab_start_index: int + padded_added_vocab_end_index: int + + org_vocab_start_index: int + org_vocab_end_index: int + added_vocab_start_index: int + added_vocab_end_index: int + + @property + def num_org_elements(self) -> int: + return self.org_vocab_end_index - self.org_vocab_start_index + + @property + def num_added_elements(self) -> int: + return self.added_vocab_end_index - self.added_vocab_start_index + + @property + def num_org_elements_padded(self) -> int: + return (self.padded_org_vocab_end_index - + self.padded_org_vocab_start_index) + + @property + def num_added_elements_padded(self) -> int: + return (self.padded_added_vocab_end_index - + self.padded_added_vocab_start_index) + + @property + def num_org_vocab_padding(self) -> int: + return self.num_org_elements_padded - self.num_org_elements + + @property + def num_added_vocab_padding(self) -> int: + return self.num_added_elements_padded - self.num_added_elements + + @property + def num_elements_padded(self) -> int: + return self.num_org_elements_padded + self.num_added_elements_padded + + def __post_init__(self): + # sanity checks + assert (self.padded_org_vocab_start_index + <= self.padded_org_vocab_end_index) + assert (self.padded_added_vocab_start_index + <= self.padded_added_vocab_end_index) + + assert self.org_vocab_start_index <= self.org_vocab_end_index + assert self.added_vocab_start_index <= self.added_vocab_end_index + + assert self.org_vocab_start_index <= self.padded_org_vocab_start_index + assert (self.added_vocab_start_index + <= self.padded_added_vocab_start_index) + assert self.org_vocab_end_index <= self.padded_org_vocab_end_index + assert self.added_vocab_end_index <= self.padded_added_vocab_end_index + + assert self.num_org_elements <= self.num_org_elements_padded + assert self.num_added_elements <= self.num_added_elements_padded + + +if SUPPORT_TC: + @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) + def get_masked_input_and_mask( + input_: torch.Tensor, org_vocab_start_index: int, + org_vocab_end_index: int, num_org_vocab_padding: int, + added_vocab_start_index: int, + added_vocab_end_index: int) -> tuple[torch.Tensor, torch.Tensor]: + # torch.compile will fuse all of the pointwise ops below + # into a single kernel, making it very fast + org_vocab_mask = (input_ >= org_vocab_start_index) & ( + input_ < org_vocab_end_index) + added_vocab_mask = (input_ >= added_vocab_start_index) & ( + input_ < added_vocab_end_index) + added_offset = added_vocab_start_index - ( + org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding + valid_offset = (org_vocab_start_index * + org_vocab_mask) + (added_offset * added_vocab_mask) + vocab_mask = org_vocab_mask | added_vocab_mask + input_ = vocab_mask * (input_ - valid_offset) + return input_, ~vocab_mask +else: + def get_masked_input_and_mask( + input_: torch.Tensor, org_vocab_start_index: int, + org_vocab_end_index: int, num_org_vocab_padding: int, + added_vocab_start_index: int, + added_vocab_end_index: int) -> tuple[torch.Tensor, torch.Tensor]: + # torch.compile will fuse all of the pointwise ops below + # into a single kernel, making it very fast + org_vocab_mask = (input_ >= org_vocab_start_index) & ( + input_ < org_vocab_end_index) + added_vocab_mask = (input_ >= added_vocab_start_index) & ( + input_ < added_vocab_end_index) + added_offset = added_vocab_start_index - ( + org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding + valid_offset = (org_vocab_start_index * + org_vocab_mask) + (added_offset * added_vocab_mask) + vocab_mask = org_vocab_mask | added_vocab_mask + input_ = vocab_mask * (input_ - valid_offset) + return input_, ~vocab_mask + + + +class VocabParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + + Adapted from torch.nn.Embedding, note that we pad the vocabulary size to + make sure it is divisible by the number of model parallel GPUs. + + In order to support various loading methods, we ensure that LoRA-added + embeddings are always at the end of TP-sharded tensors. In other words, + we shard base embeddings and LoRA embeddings separately (both padded), + and place them in the same tensor. + In this example, we will have the original vocab size = 1010, + added vocab size = 16 and padding to 64. Therefore, the total + vocab size with padding will be 1088 (because we first pad 1010 to + 1024, add 16, and then pad to 1088). + Therefore, the tensor format looks like the following: + TP1, rank 0 (no sharding): + |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >| + corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1025 | -1 | ... | -1 | + index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 | + + TP2, rank 0: + |< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >| + corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1010 | ... | 1025 | -1 | ... | -1 | + index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 528 | ... | 543 | + TP2, rank 1: + |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >| + corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 | + index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 528 | ... | 543 | + + Args: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + params_dtype: type of the parameters. + org_num_embeddings: original vocabulary size (without LoRA). + padding_size: padding size for the vocabulary. + quant_config: quant config for the layer + prefix: full name of the layer in the state dict + """ # noqa: E501 + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + + # Keep the input dimensions. + tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.num_embeddings = num_embeddings + self.padding_size = padding_size + self.org_vocab_size = org_num_embeddings or num_embeddings + num_added_embeddings = num_embeddings - self.org_vocab_size + self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, + self.padding_size) + self.num_embeddings_padded = pad_vocab_size( + self.org_vocab_size_padded + num_added_embeddings, + self.padding_size) + assert self.org_vocab_size_padded <= self.num_embeddings_padded + + self.shard_indices = self._get_indices(self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, tp_rank, + self.tp_size) + self.embedding_dim = embedding_dim + + quant_method = None + if quant_config is not None: + quant_method = quant_config.get_quant_method(self, prefix=prefix) + if quant_method is None: + quant_method = UnquantizedEmbeddingMethod() + + # If we are making an embedding layer, then our quantization linear + # method must implement the embedding operation. If we are another + # layer type like ParallelLMHead, this is not important. + is_embedding_layer = type(self) is VocabParallelEmbedding + quant_method_implements_embedding = method_has_implemented_embedding( + type(quant_method)) + if is_embedding_layer and not quant_method_implements_embedding: + raise NotImplementedError( + f"The class {type(quant_method).__name__} must implement " + "the 'embedding' method, see UnquantizedEmbeddingMethod.") + + self.quant_method: QuantizeMethodBase = quant_method + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + # Divide the weight matrix along the vocaburaly dimension. + self.num_added_embeddings = self.num_embeddings - self.org_vocab_size + self.num_embeddings_per_partition = divide(self.num_embeddings_padded, + self.tp_size) + assert (self.shard_indices.num_elements_padded == + self.num_embeddings_per_partition) + self.num_org_embeddings_per_partition = ( + self.shard_indices.org_vocab_end_index - + self.shard_indices.org_vocab_start_index) + self.num_added_embeddings_per_partition = ( + self.shard_indices.added_vocab_end_index - + self.shard_indices.added_vocab_start_index) + + self.quant_method.create_weights(self, + self.embedding_dim, + [self.num_embeddings_per_partition], + self.embedding_dim, + self.num_embeddings_padded, + params_dtype=params_dtype, + weight_loader=self.weight_loader) + from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce + self.tbo_all_reduce = tbo_all_reduce + + @classmethod + def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int, + vocab_size: int, org_vocab_size: int, tp_rank: int, + tp_size: int) -> VocabParallelEmbeddingShardIndices: + """Get start and end indices for vocab parallel embedding, following the + layout outlined in the class docstring, based on the given tp_rank and + tp_size.""" + num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded + padded_org_vocab_start_index, padded_org_vocab_end_index = ( + vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, + tp_size)) + padded_added_vocab_start_index, padded_added_vocab_end_index = ( + vocab_range_from_global_vocab_size(num_added_embeddings_padded, + tp_rank, + tp_size, + offset=org_vocab_size)) + # remove padding + org_vocab_start_index = min(padded_org_vocab_start_index, + org_vocab_size) + org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size) + added_vocab_start_index = min(padded_added_vocab_start_index, + vocab_size) + added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size) + return VocabParallelEmbeddingShardIndices( + padded_org_vocab_start_index, padded_org_vocab_end_index, + padded_added_vocab_start_index, padded_added_vocab_end_index, + org_vocab_start_index, org_vocab_end_index, + added_vocab_start_index, added_vocab_end_index) + + def get_sharded_to_full_mapping(self) -> Optional[list[int]]: + """Get a mapping that can be used to reindex the gathered + logits for sampling. + + During sampling, we gather logits from all ranks. The relationship + of index->token_id will follow the same format as outlined in the class + docstring. However, after the gather, we want to reindex the final + logits tensor to map index->token_id one-to-one (the index is always + equal the token_id it corresponds to). The indices returned by this + method allow us to do that. + """ + if self.tp_size < 2: + return None + + base_embeddings: list[int] = [] + added_embeddings: list[int] = [] + padding: list[int] = [] + for tp_rank in range(self.tp_size): + shard_indices = self._get_indices(self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, tp_rank, + self.tp_size) + range_start = self.num_embeddings_per_partition * tp_rank + range_end = self.num_embeddings_per_partition * (tp_rank + 1) + base_embeddings.extend( + range(range_start, + range_start + shard_indices.num_org_elements)) + padding.extend( + range(range_start + shard_indices.num_org_elements, + range_start + shard_indices.num_org_elements_padded)) + added_embeddings.extend( + range( + range_start + shard_indices.num_org_elements_padded, + range_start + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements)) + padding.extend( + range( + range_start + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements, + range_start + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements_padded)) + assert (range_start + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements_padded == range_end) + ret = base_embeddings + added_embeddings + padding + assert len(ret) == self.num_embeddings_padded + return ret + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + output_dim = getattr(param, "output_dim", None) + packed_dim = getattr(param, "packed_dim", None) + + # If the parameter is a gguf weight, then load it directly. + if getattr(param, "is_gguf_weight_type", None): + param.data.copy_(loaded_weight) + param.weight_type = loaded_weight.item() + return + elif isinstance(param, UninitializedParameter): + shape = list(loaded_weight.shape) + if output_dim is not None: + shape[output_dim] = self.num_embeddings_per_partition + param.materialize(tuple(shape), dtype=loaded_weight.dtype) + + # If parameter does not have output dim, then it should + # be copied onto all gpus (e.g. g_idx for act_order gptq). + if output_dim is None: + assert param.data.shape == loaded_weight.shape + param.data.copy_(loaded_weight) + return + + # Shard indexes for loading the weight + start_idx = self.shard_indices.org_vocab_start_index + shard_size = self.shard_indices.org_vocab_end_index - start_idx + + # If param packed on the same dim we are sharding on, then + # need to adjust offsets of loaded weight by pack_factor. + if packed_dim is not None and packed_dim == output_dim: + packed_factor = param.packed_factor if isinstance( + param, BasevLLMParameter) else param.pack_factor + assert loaded_weight.shape[output_dim] == (self.org_vocab_size // + param.packed_factor) + start_idx = start_idx // packed_factor + shard_size = shard_size // packed_factor + else: + assert loaded_weight.shape[output_dim] == self.org_vocab_size + + # Copy the data. Select chunk corresponding to current shard. + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + + # if envs.VLLM_USE_NN and self.quant_method is not None: + # loaded_weight = loaded_weight.t() + + if current_platform.is_hpu(): + # FIXME(kzawora): Weight copy with slicing bugs out on Gaudi here, + # so we're using a workaround. Remove this when fixed in + # HPU PT bridge. + padded_weight = torch.cat([ + loaded_weight, + torch.zeros(param.shape[0] - loaded_weight.shape[0], + *loaded_weight.shape[1:]) + ]) + param.data.copy_(padded_weight) + else: + param[:loaded_weight.shape[0]].data.copy_(loaded_weight) + param[loaded_weight.shape[0]:].data.fill_(0) + + def forward(self, input_): + if self.tp_size > 1: + # Build the mask. + masked_input, input_mask = get_masked_input_and_mask( + input_, self.shard_indices.org_vocab_start_index, + self.shard_indices.org_vocab_end_index, + self.shard_indices.num_org_vocab_padding, + self.shard_indices.added_vocab_start_index, + self.shard_indices.added_vocab_end_index) + else: + masked_input = input_ + # Get the embeddings. + output_parallel = self.quant_method.embedding(self, + masked_input.long()) + # Mask the output embedding. + if self.tp_size > 1: + output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) + # Reduce across all the model parallel GPUs. + if envs.VLLM_ENABLE_TBO: + output = self.tbo_all_reduce(output_parallel) + else: + output = tensor_model_parallel_all_reduce(output_parallel) + return output + + def extra_repr(self) -> str: + s = f"num_embeddings={self.num_embeddings_per_partition}" + s += f", embedding_dim={self.embedding_dim}" + s += f", org_vocab_size={self.org_vocab_size}" + s += f', num_embeddings_padded={self.num_embeddings_padded}' + s += f', tp_size={self.tp_size}' + return s + + +class ParallelLMHead(VocabParallelEmbedding): + """Parallelized LM head. + + Output logits weight matrices used in the Sampler. The weight and bias + tensors are padded to make sure they are divisible by the number of + model parallel GPUs. + + Args: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + bias: whether to use bias. + params_dtype: type of the parameters. + org_num_embeddings: original vocabulary size (without LoRA). + padding_size: padding size for the vocabulary. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + bias: bool = False, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__(num_embeddings, embedding_dim, params_dtype, + org_num_embeddings, padding_size, quant_config, + prefix) + self.quant_config = quant_config + if bias: + self.bias = Parameter( + torch.empty(self.num_embeddings_per_partition, + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def tie_weights(self, embed_tokens: VocabParallelEmbedding): + """Tie the weights with word embeddings.""" + # GGUF quantized embed_tokens. + if self.quant_config and self.quant_config.get_name() == "gguf": + return embed_tokens + else: + self.weight = embed_tokens.weight + return self + + def forward(self, input_): + del input_ + raise RuntimeError("LMHead's weights should be used in the sampler.") \ No newline at end of file diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py new file mode 100644 index 0000000..78681a0 --- /dev/null +++ b/vllm/model_executor/model_loader/__init__.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +from torch import nn + +from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.bitsandbytes_loader import ( + BitsAndBytesModelLoader) +from vllm.model_executor.model_loader.default_loader import DefaultModelLoader +from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader +from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader +from vllm.model_executor.model_loader.runai_streamer_loader import ( + RunaiModelStreamerLoader) +from vllm.model_executor.model_loader.sharded_state_loader import ( + ShardedStateLoader) +from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader +from vllm.model_executor.model_loader.utils import ( + get_architecture_class_name, get_model_architecture, get_model_cls) + + +def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: + """Get a model loader based on the load format.""" + if isinstance(load_config.load_format, type): + return load_config.load_format(load_config) + + if load_config.load_format == LoadFormat.DUMMY: + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.TENSORIZER: + return TensorizerLoader(load_config) + + if load_config.load_format == LoadFormat.SHARDED_STATE: + return ShardedStateLoader(load_config) + + if load_config.load_format == LoadFormat.BITSANDBYTES: + return BitsAndBytesModelLoader(load_config) + + if load_config.load_format == LoadFormat.GGUF: + return GGUFModelLoader(load_config) + + if load_config.load_format == LoadFormat.RUNAI_STREAMER: + return RunaiModelStreamerLoader(load_config) + + if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED: + return ShardedStateLoader(load_config, runai_model_streamer=True) + + return DefaultModelLoader(load_config) + + +def get_model(*, + vllm_config: VllmConfig, + model_config: Optional[ModelConfig] = None) -> nn.Module: + loader = get_model_loader(vllm_config.load_config) + if model_config is None: + model_config = vllm_config.model_config + return loader.load_model(vllm_config=vllm_config, + model_config=model_config) + + +__all__ = [ + "get_model", + "get_model_loader", + "get_architecture_class_name", + "get_model_architecture", + "get_model_cls", + "BaseModelLoader", + "BitsAndBytesModelLoader", + "GGUFModelLoader", + "DefaultModelLoader", + "DummyModelLoader", + "RunaiModelStreamerLoader", + "ShardedStateLoader", + "TensorizerLoader", +] diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py new file mode 100644 index 0000000..5018c7d --- /dev/null +++ b/vllm/model_executor/model_loader/base_loader.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn + +from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.model_executor.model_loader.utils import ( + initialize_model, process_weights_after_loading, set_default_torch_dtype) + + +class BaseModelLoader(ABC): + """Base class for model loaders.""" + + def __init__(self, load_config: LoadConfig): + self.load_config = load_config + + @abstractmethod + def download_model(self, model_config: ModelConfig) -> None: + """Download a model so that it can be immediately loaded.""" + raise NotImplementedError + + @abstractmethod + def load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: + """Load weights into a model. This standalone API allows + inplace weights loading for an already-initialized model""" + raise NotImplementedError + + def load_model(self, vllm_config: VllmConfig, + model_config: ModelConfig) -> nn.Module: + """Load a model with the given configurations.""" + device_config = vllm_config.device_config + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = initialize_model(vllm_config=vllm_config, + model_config=model_config) + # Quantization does not happen in `load_weights` but after it + self.load_weights(model, model_config) + process_weights_after_loading(model, model_config, target_device) + return model.eval() diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py new file mode 100644 index 0000000..8e330f7 --- /dev/null +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -0,0 +1,613 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: SIM117 +import fnmatch +import glob +import itertools +import math +import os +from collections.abc import Generator +from typing import Any, Callable, Optional + +import numpy as np +import torch +from huggingface_hub import HfApi +from torch import nn +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME + +from vllm.config import LoadConfig, ModelConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +# yapf: enable +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.utils import (ParamMapping, + set_default_torch_dtype) +from vllm.model_executor.model_loader.weight_utils import ( + download_safetensors_index_file_from_hf, download_weights_from_hf, + filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, + pt_weights_iterator, safetensors_weights_iterator) +from vllm.model_executor.models import is_pooling_model +from vllm.model_executor.utils import (get_packed_modules_mapping, + set_weight_attrs) +from vllm.platforms import current_platform + +# yapf conflicts with isort for this block + +logger = init_logger(__name__) + + +class BitsAndBytesModelLoader(BaseModelLoader): + """Model loader to load model weights with BitAndBytes quantization.""" + + possible_config_file_names = ["adapter_config.json"] + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + + # Save the module names without sharding. + self.unsharded_weights_modules: list[str] = [] + # Save the module names that are sharded by column. + self.column_sharded_weights_modules: list[str] = [] + # Modules whose weights might have fused on disk + # we need their output_sizes to make shard in flight correctly with TP + self.maybe_fused_weights_modules: dict[str, list[int]] = {} + # Store all module names (from transformers) that support + # BNB quantization. + self.target_modules: list[str] = [] + # mapping weight names from transformers to vllm. + self.weight_mapper: Callable = lambda name: name + self.pre_quant: bool = False + self.load_8bit: bool = False + self.is_pool_model: bool = False + + def _get_weight_files( + self, + model_name_or_path: str, + allowed_patterns: list[str], + revision: Optional[str] = None, + ) -> tuple[str, list[str], str]: + """Retrieve weight files. Download the files if necessary. + + Return the weight files and the file pattern.""" + is_local = os.path.isdir(model_name_or_path) + + if is_local: + for pattern in allowed_patterns: + weight_files = glob.glob( + os.path.join(model_name_or_path, pattern)) + if weight_files: + return model_name_or_path, weight_files, pattern + else: + hf_api = HfApi() + repo_files = hf_api.list_repo_files(repo_id=model_name_or_path) + for pattern in allowed_patterns: + matching_files = fnmatch.filter(repo_files, pattern) + if matching_files: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + [pattern], + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + return hf_folder, glob.glob( + os.path.join(hf_folder, pattern)), pattern + + raise RuntimeError( + f"No model weights found in: `{model_name_or_path}`") + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]) -> tuple[list[str], bool]: + """Prepare weight files for the model.""" + + allowed_patterns = ["*.safetensors", "*.bin", "*.pt"] + + hf_folder, hf_weights_files, matched_pattern = self._get_weight_files( + model_name_or_path, allowed_patterns, revision) + + use_safetensors = matched_pattern == "*.safetensors" + is_local = os.path.isdir(model_name_or_path) + index_file = SAFE_WEIGHTS_INDEX_NAME + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. + if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, + index_file, + self.load_config.download_dir, + revision, + ) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder, index_file) + else: + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_weights_files, use_safetensors + + def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): + + def _maybe_pool_model(module_name: str): + # For pool model, we need to add the prefix `model.` + # for the weight name if possible. + if self.is_pool_model and self.target_modules[0]. \ + startswith("model.") and not module_name.startswith( + "model."): + return "model." + module_name + + return module_name + + if use_safetensors: + iterator = safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + else: + iterator = pt_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + self.load_config.pt_load_map_location, + ) + for org_name, param in iterator: + # mapping weight names from transformers to vllm while preserving + # original names. + mapped_name = self.weight_mapper(org_name) + mapped_name = _maybe_pool_model(mapped_name) + + yield org_name, mapped_name, param + + def _get_quantized_weights_iterator( + self, + model_name_or_path: str, + revision: Optional[str], + ) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str, + Any]]: + """Get an iterator to the model weights with bitsandbytes quantization, + as well as the quantization state dictionary.""" + + # only load the bitsandbytes module when needed + try: + import bitsandbytes + + if bitsandbytes.__version__ < "0.46.1": + raise ImportError("bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.46.1.") + except ImportError as err: + raise ImportError("Please install bitsandbytes>=0.46.1 via " + "`pip install bitsandbytes>=0.46.1` to use " + "bitsandbytes quantizer.") from err + + hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision) + + quant_state_dict: dict[str, Any] = {} + + if self.pre_quant: + if self.load_8bit: + return self._quantized_8bit_generator( + hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict + else: + return self._quantized_4bit_generator( + hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict + + return self._unquantized_generator(hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict + + def _is_8bit_weight_name(self, weight_name: str): + quantized_suffix = {".scb", ".weight_format"} + return any(weight_name.lower().endswith(suffix) + for suffix in quantized_suffix) + + def _is_4bit_weight_name(self, weight_name: str): + quantized_suffix = { + "absmax", + "quant_map", + "nested_absmax", + "nested_quant_map", + "bitsandbytes", + } + suffix = weight_name.split(".")[-1] + return any(q_suffix in suffix for q_suffix in quantized_suffix) + + def _quantized_8bit_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + for ( + org_weight_name, + mapped_weight_name, + weight_tensor, + ) in self._hf_weight_iter(hf_weights_files, use_safetensors): + if not mapped_weight_name.lower().endswith(".scb"): + continue + + weight_key = mapped_weight_name.lower().replace(".scb", ".weight") + quant_state_dict[weight_key] = weight_tensor + + for ( + org_weight_name, + mapped_weight_name, + weight_tensor, + ) in self._hf_weight_iter(hf_weights_files, use_safetensors): + if self._is_8bit_weight_name(mapped_weight_name): + continue + + if mapped_weight_name in quant_state_dict: + set_weight_attrs(weight_tensor, {"load_in_8bit": True}) + yield org_weight_name, weight_tensor + else: + yield org_weight_name, weight_tensor + + def _quantized_4bit_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + from bitsandbytes.functional import QuantState + + # First iterate over all quant state weights + weight_iterator = self._hf_weight_iter(hf_weights_files, + use_safetensors) + temp_state_dict = {} + for ( + org_weight_name, + mapped_weight_name, + weight_tensor, + ) in weight_iterator: + if not self._is_4bit_weight_name(mapped_weight_name): + continue + # bitsandbytes library requires + # weight.quant_state.bitsandbytes__* in CPU + if "quant_state.bitsandbytes" in mapped_weight_name: + temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data + else: + temp_state_dict[mapped_weight_name] = weight_tensor + + # Closure to parse quant_state for each prequant weight + def _parse_quant_state(param_name: str, + temp_state_dict: dict) -> QuantState: + quant_state = {} + for k in temp_state_dict: + if param_name + "." in k: + quant_state[k] = temp_state_dict[k] + + return QuantState.from_dict(quant_state, + device=current_platform.device_type) + + # Second iterate over all prequant and normal weights + # pre quantized weights would have a quant_state + for ( + org_weight_name, + mapped_weight_name, + weight_tensor, + ) in self._hf_weight_iter(hf_weights_files, use_safetensors): + if self._is_4bit_weight_name(mapped_weight_name): + continue + + if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4" + in temp_state_dict) or ( + f"{mapped_weight_name}.quant_state.bitsandbytes__fp4" + in temp_state_dict): + quant_state = _parse_quant_state(mapped_weight_name, + temp_state_dict) + quant_state_dict[mapped_weight_name] = quant_state + yield org_weight_name, weight_tensor + else: + yield org_weight_name, weight_tensor + + def _unquantized_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + from bitsandbytes.functional import quantize_4bit + + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + for ( + org_weight_name, + mapped_weight_name, + weight_tensor, + ) in self._hf_weight_iter(hf_weights_files, use_safetensors): + if any(target_module in mapped_weight_name + for target_module in self.target_modules + ) and mapped_weight_name.endswith(".weight"): + # Without sharding + if any( + mapped_weight_name.startswith(module) + for module in self.unsharded_weights_modules): + weight_sub_tensor = weight_tensor + # Shard by column + elif any( + mapped_weight_name.startswith(module) + for module in self.column_sharded_weights_modules): + total_size = weight_tensor.size(-1) + start_index = total_size // tp_size * tp_rank + end_index = total_size // tp_size * (tp_rank + 1) + weight_sub_tensor = weight_tensor[..., + start_index:end_index] + # Weights have fused on disk. In this case, we assume that the + # weight and module use same name. + elif any( + mapped_weight_name.startswith(module) + for module in self.maybe_fused_weights_modules): + # special case for fused weights + # get the size of each shard weight tensor + total_shard_sizes = next( + (sizes for module, sizes in + self.maybe_fused_weights_modules.items() + if mapped_weight_name.startswith(module))) + total_size = weight_tensor.size(0) + assert total_size == sum(total_shard_sizes) + # get the start/end index of each shard weight tensor + total_start_index = list( + itertools.accumulate([0] + total_shard_sizes))[:-1] + shard_weights_index = [( + idx + size // tp_size * tp_rank, + idx + size // tp_size * (tp_rank + 1), + ) for idx, size in zip(total_start_index, + total_shard_sizes)] + # slice and reorder the weight tensor + weight_tensor = [ + weight_tensor[start_index:end_index, ...] + for start_index, end_index in shard_weights_index + ] + weight_sub_tensor = torch.cat(weight_tensor, dim=0) + # Shard by row + else: + total_size = weight_tensor.size(0) + start_index = total_size // tp_size * tp_rank + end_index = total_size // tp_size * (tp_rank + 1) + weight_sub_tensor = weight_tensor[start_index:end_index, + ...] + + # bitsandbytes requires data in GPU + if weight_sub_tensor.is_cuda: + loaded_weight = weight_sub_tensor + else: + loaded_weight = weight_sub_tensor.cuda() + + # remove the following after the issue is fixed: + # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342 + if loaded_weight.is_contiguous() is False: + loaded_weight = loaded_weight.contiguous() + + with set_default_torch_dtype(torch.float32): + processed_weight, quant_state = quantize_4bit( + loaded_weight, + compress_statistics=True, + quant_type="nf4", + ) + + quant_state_dict[mapped_weight_name] = quant_state + else: + processed_weight = weight_tensor + yield org_weight_name, processed_weight + + def _get_bnb_target_modules(self, model: nn.Module) -> None: + """ + Identify and collect all modules that support BitsAndBytes + quantization. + """ + for name, module in model.named_modules(): + if (isinstance(module, LinearBase) + and hasattr(module.quant_method, "quant_config")): + if modules_info := self.modules_mapping.get_sub_modules(name): + # Map vllm's names to transformers's names. + rep_name, sub_modules = modules_info + for sub_name in sub_modules: + self.target_modules.append( + name.replace(rep_name, sub_name)) + # Add original module name even if the module has stacked map, + # in case model has a mixture of disk-merged and disk-split + # weights with same last name. + self.target_modules.append(name) + + assert (self.target_modules + ), "vllm currently does not support BNB quantization for" + f" {type(model).__name__}" + + def _classify_module_sharding(self, model: nn.Module): + """ + Categorize modules based on their weight sharding requirements + for tensor parallelism. + """ + for name, module in model.named_modules(): + # Some modules like `ReplicatedLinear` should not have their weights + # sharded. The reason for implementing it this way is to avoid new + # static variable in the model implementation. + if isinstance(module, (ReplicatedLinear, )): + self.unsharded_weights_modules.append(name) + # `QKVParallelLinear` and `MergedColumnParallelLinear` might have + # fused weights on disk. We need to use the output sizes of these + # modules to shard the weights correctly. + elif isinstance(module, + (QKVParallelLinear, MergedColumnParallelLinear)): + self.maybe_fused_weights_modules[name] = module.output_sizes + # In TP, these weights are partitioned along the column + # dimension (dim=-1) + elif isinstance(module, (RowParallelLinear, )): + self.column_sharded_weights_modules.append(name) + + def _verify_model_compatibility(self, model: nn.Module, + model_config: ModelConfig) -> None: + """ + Verify that the model is compatible with BitsAndBytes quantization. + """ + if not hasattr(model, "load_weights"): + raise AttributeError( + "The required method 'load_weights' is not defined in class" + f" {type(model).__name__}.") + + if not hasattr(model, "packed_modules_mapping"): + raise AttributeError( + f"Model {type(model).__name__} does not support BitsAndBytes " + "quantization yet. No 'packed_modules_mapping' found.") + + quant_config = getattr(model_config.hf_config, "quantization_config", + None) + if quant_config is not None: + quant_method = quant_config.get("quant_method") + if quant_method == "bitsandbytes": + self.pre_quant = True + else: + raise ValueError( + f"BitsAndBytes loader does not support {quant_method} " + "quantization") + + # The quant_states in pre_quantized models cannot work with a split + # weight tensor. So TP does not work with pre_quantized bnb models. + if self.pre_quant and get_tensor_model_parallel_world_size() > 1: + raise ValueError( + "Prequant BitsAndBytes models with tensor parallelism is not " + "supported. Please try with pipeline parallelism.") + if self.pre_quant: + self.load_8bit = quant_config.get("load_in_8bit", False) + + def _initialize_loader_state(self, model: nn.Module, + model_config: ModelConfig) -> None: + """ + Initialize the loader's internal state based on the model and + configuration. + """ + self.is_pool_model = is_pooling_model(model) + self.modules_mapping = ParamMapping(get_packed_modules_mapping(model)) + + # For some models like Molmo, we need to use hf_to_vllm_mapper + # to ensure correct loading of weights. + if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None): + self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name) + + self._get_bnb_target_modules(model) + self._classify_module_sharding(model) + + def load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: + + self._verify_model_compatibility(model, model_config) + self._initialize_loader_state(model, model_config) + + logger.info("Loading weights with BitsAndBytes quantization. " + "May take a while ...") + qweight_iterator, quant_state_dict = ( + self._get_quantized_weights_iterator( + model_config.model, + model_config.revision, + )) + weights_to_load = {name for name, _ in model.named_parameters()} + loaded_weights = model.load_weights(qweight_iterator) + # Some models may have weights loading tracker unimplemented. + if loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + raise ValueError("Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}") + + param_dict = dict(model.named_parameters()) + stacked_quant_state_dict: dict[str, dict[int, Any]] = {} + # TODO: Change this lazy import to normal import + # after the checks are updated to run on a new version + from vllm.model_executor.models.utils import is_pp_missing_parameter + + for quant_param_name in quant_state_dict: + if is_pp_missing_parameter(quant_param_name, model): + continue + + non_stacked_param_name = quant_param_name + + shard_index = 0 + for shard_name, ( + weight_name, + index, + ) in self.modules_mapping.inverse_packed_mapping.items(): + # Some models, such as MiniCPM V2.5/2.6, contain both + # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj' + # from being incorrectly identified as being present in + # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight + shard_pos = quant_param_name.find(shard_name) + can_correct_rename = (shard_pos + > 0) and (quant_param_name[shard_pos - 1] + == ".") + # If the quant_param_name is packed, it won't occur in the + # param_dict before renaming. + new_quant_param_name = quant_param_name.replace( + shard_name, weight_name) + need_rename = (quant_param_name not in param_dict) \ + and (new_quant_param_name in param_dict) + if can_correct_rename and need_rename: + shard_index = index + quant_param_name = new_quant_param_name + break + + # Models like Clip/Siglip may skip some layers in initialization, + # causing unused quant_param_name in state_dict. + if quant_param_name not in param_dict: + continue + + if quant_param_name not in stacked_quant_state_dict: + stacked_quant_state_dict[quant_param_name] = {} + + stacked_quant_state_dict[quant_param_name][shard_index] = ( + quant_state_dict[non_stacked_param_name]) + + # save quant_states and offsets as the attributes of the parameters + for param_name, param in param_dict.items(): + if param_name in stacked_quant_state_dict: + quant_states = stacked_quant_state_dict[param_name] + # Dequantize double quantized values during weight loading. + dequantize_dq(quant_states) + set_weight_attrs(param, {"bnb_quant_state": quant_states}) + + pack_ratio = getattr(param, "pack_factor", -1) + if pack_ratio == -1: + raise ValueError( + f"pack_factor not set for parameter {param_name}.") + + num_elements = [0] * len(quant_states) + for seq, quant_state in quant_states.items(): + num_elements[seq] = (math.prod(quant_state.shape) // + pack_ratio) + + offsets = np.concatenate(([0], np.cumsum(num_elements))) + # Make torch infer_schema happy + offsets = torch.tensor(offsets).cpu() + set_weight_attrs(param, {"bnb_shard_offsets": offsets}) + + if self.load_8bit: + set_weight_attrs( + param, {"matmul_state": [None] * len(quant_states)}) + torch.cuda.empty_cache() + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, model_config.revision) + + +def dequantize_dq(quant_states: dict) -> None: + """ + When BNB employs Double Quantization, we perform the dequantization of + these constants during weight loading rather than at inference time, + thereby avoiding this computational overhead during inference. This comes + at the cost of increased memory usage. + """ + from bitsandbytes.functional import QuantState, dequantize_blockwise + for _, quant_state in quant_states.items(): + # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356 + if isinstance(quant_state, QuantState) and quant_state.nested: + absmax = dequantize_blockwise(quant_state.absmax, + quant_state.state2) + absmax += quant_state.offset + if absmax.dtype != torch.float32: + absmax = absmax.float() + quant_state.absmax = absmax + quant_state.nested = False + quant_state.offset = None + quant_state.state2 = None diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py new file mode 100644 index 0000000..4624ff0 --- /dev/null +++ b/vllm/model_executor/model_loader/default_loader.py @@ -0,0 +1,282 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import dataclasses +import glob +import os +import time +from collections.abc import Generator, Iterable +from typing import Optional, cast + +import huggingface_hub +import torch +from torch import nn +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME + +from vllm import envs +from vllm.config import LoadConfig, LoadFormat, ModelConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.weight_utils import ( + download_safetensors_index_file_from_hf, download_weights_from_hf, + fastsafetensors_weights_iterator, filter_duplicate_safetensors_files, + filter_files_not_needed_for_inference, get_lock, np_cache_weights_iterator, + pt_weights_iterator, safetensors_weights_iterator) +from vllm.platforms import current_platform + +logger = init_logger(__name__) + + +class DefaultModelLoader(BaseModelLoader): + """Model loader that can load different file types from disk.""" + + @dataclasses.dataclass + class Source: + """A source for weights.""" + + model_or_path: str + """The model ID or path.""" + + revision: Optional[str] + """The optional model revision.""" + + prefix: str = "" + """A prefix to prepend to all weights.""" + + fall_back_to_pt: bool = True + """Whether .pt weights can be used.""" + + allow_patterns_overrides: Optional[list[str]] = None + """If defined, weights will load exclusively using these patterns.""" + + counter_before_loading_weights: float = 0.0 + counter_after_loading_weights: float = 0.0 + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _maybe_download_from_modelscope( + self, model: str, revision: Optional[str]) -> Optional[str]: + """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + + Returns the path to the downloaded model, or None if the model is not + downloaded from ModelScope.""" + if envs.VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + if not os.path.exists(model): + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model, self.load_config.download_dir): + model_path = snapshot_download( + model_id=model, + cache_dir=self.load_config.download_dir, + local_files_only=huggingface_hub.constants. + HF_HUB_OFFLINE, + revision=revision, + ignore_file_pattern=self.load_config.ignore_patterns, + ) + else: + model_path = model + return model_path + return None + + def _prepare_weights( + self, + model_name_or_path: str, + revision: Optional[str], + fall_back_to_pt: bool, + allow_patterns_overrides: Optional[list[str]], + ) -> tuple[str, list[str], bool]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + model_name_or_path = (self._maybe_download_from_modelscope( + model_name_or_path, revision) or model_name_or_path) + + is_local = os.path.isdir(model_name_or_path) + load_format = self.load_config.load_format + use_safetensors = False + index_file = SAFE_WEIGHTS_INDEX_NAME + # Some quantized models use .pt files for storing the weights. + if load_format == LoadFormat.AUTO: + allow_patterns = ["*.safetensors", "*.bin"] + elif (load_format == LoadFormat.SAFETENSORS + or load_format == LoadFormat.FASTSAFETENSORS): + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.MISTRAL: + use_safetensors = True + allow_patterns = ["consolidated*.safetensors"] + index_file = "consolidated.safetensors.index.json" + elif load_format == LoadFormat.PT: + allow_patterns = ["*.pt"] + elif load_format == LoadFormat.NPCACHE: + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if allow_patterns_overrides is not None: + allow_patterns = allow_patterns_overrides + + if not is_local: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + else: + hf_folder = model_name_or_path + + hf_weights_files: list[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. + if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, + index_file, + self.load_config.download_dir, + revision, + ) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder, index_file) + else: + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_folder, hf_weights_files, use_safetensors + + def _get_weights_iterator( + self, source: "Source" + ) -> Generator[tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( + source.model_or_path, source.revision, source.fall_back_to_pt, + source.allow_patterns_overrides) + if self.load_config.load_format == LoadFormat.NPCACHE: + # Currently np_cache only support *.bin checkpoints + assert use_safetensors is False + weights_iterator = np_cache_weights_iterator( + source.model_or_path, + self.load_config.download_dir, + hf_folder, + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + elif use_safetensors: + if self.load_config.load_format == LoadFormat.FASTSAFETENSORS: + weights_iterator = fastsafetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + else: + weights_iterator = safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + else: + weights_iterator = pt_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + self.load_config.pt_load_map_location, + ) + + if current_platform.is_tpu(): + # In PyTorch XLA, we should call `xm.mark_step` frequently so that + # not too many ops are accumulated in the XLA program. + import torch_xla.core.xla_model as xm + + def _xla_weights_iterator(iterator: Generator): + for weights in iterator: + yield weights + xm.mark_step() + + weights_iterator = _xla_weights_iterator(weights_iterator) + + elif current_platform.is_hpu(): + import habana_frameworks.torch.core as htcore + + def _hpu_weights_iterator(iterator: Generator): + for weights in iterator: + yield weights + htcore.mark_step() + + weights_iterator = _hpu_weights_iterator(weights_iterator) + + if self.counter_before_loading_weights == 0.0: + self.counter_before_loading_weights = time.perf_counter() + # Apply the prefix. + return ((source.prefix + name, tensor) + for (name, tensor) in weights_iterator) + + def get_all_weights( + self, + model_config: ModelConfig, + model: nn.Module, + ) -> Generator[tuple[str, torch.Tensor], None, None]: + primary_weights = DefaultModelLoader.Source( + model_config.model, + model_config.revision, + prefix="", + fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", + True), + allow_patterns_overrides=getattr(model, "allow_patterns_overrides", + None), + ) + yield from self._get_weights_iterator(primary_weights) + + secondary_weights = cast( + Iterable[DefaultModelLoader.Source], + getattr(model, "secondary_weights", ()), + ) + for source in secondary_weights: + yield from self._get_weights_iterator(source) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, + model_config.revision, + fall_back_to_pt=True, + allow_patterns_overrides=None) + + def load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: + weights_to_load = {name for name, _ in model.named_parameters()} + loaded_weights = model.load_weights( + self.get_all_weights(model_config, model)) + self.counter_after_loading_weights = time.perf_counter() + logger.info( + "Loading weights took %.2f seconds", + self.counter_after_loading_weights - + self.counter_before_loading_weights) + # We only enable strict check for non-quantized models + # that have loaded weights tracking currently. + if model_config.quantization is None and loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + raise ValueError("Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}") diff --git a/vllm/model_executor/model_loader/dummy_loader.py b/vllm/model_executor/model_loader/dummy_loader.py new file mode 100644 index 0000000..f4a7da5 --- /dev/null +++ b/vllm/model_executor/model_loader/dummy_loader.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch.nn as nn + +from vllm.config import LoadConfig, ModelConfig +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.weight_utils import ( + initialize_dummy_weights) + + +class DummyModelLoader(BaseModelLoader): + """Model loader that will set model weights to random values.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def download_model(self, model_config: ModelConfig) -> None: + pass # Nothing to download + + def load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. + initialize_dummy_weights(model) diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py new file mode 100644 index 0000000..203c807 --- /dev/null +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +from collections.abc import Generator + +import gguf +import torch +import torch.nn as nn +from transformers import AutoModelForCausalLM + +from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.utils import ( + initialize_model, process_weights_after_loading, set_default_torch_dtype) +from vllm.model_executor.model_loader.weight_utils import ( + get_gguf_extra_tensor_names, gguf_quant_weights_iterator) + + +class GGUFModelLoader(BaseModelLoader): + """ + Model loader that can load GGUF files. This is useful for loading models + that are quantized with GGUF and saved in the GGUF format. This loader + supports loading both full models and sharded models. + """ + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _prepare_weights(self, model_name_or_path: str): + if os.path.isfile(model_name_or_path): + return model_name_or_path + else: + raise ValueError(f"{model_name_or_path} is not a file.") + + def _get_gguf_weights_map(self, model_config: ModelConfig): + """ + GGUF uses this naming convention for their tensors from HF checkpoint: + `blk.N.BB.weight` and `blk.N.BB.bias` + where N signifies the block number of a layer, and BB signifies the + attention/mlp layer components. + See "Standardized tensor names" in + https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details. + """ + config = model_config.hf_config + model_type = config.model_type + gguf_to_hf_name_map = {} + # hack: ggufs have a different name than transformers + if model_type == "cohere": + model_type = "command-r" + if model_type in ("deepseek_v3", "deepseek_v2"): + model_type = "deepseek2" + # GGUF layer map assumes that we will have a merged expert weights + # so we need to map them manually + for idx in range(config.num_hidden_layers): + gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = \ + f"model.layers.{idx}.mlp.gate.e_score_correction_bias" + gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \ + f"model.layers.{idx}.mlp.experts.0.down_proj.weight" + gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \ + f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" + gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \ + f"model.layers.{idx}.mlp.experts.0.up_proj.weight" + + arch = None + for key, value in gguf.MODEL_ARCH_NAMES.items(): + if value == model_type: + arch = key + break + if arch is None: + raise RuntimeError(f"Unknown gguf model_type: {model_type}") + num_layers = config.num_hidden_layers + name_map = gguf.get_tensor_name_map(arch, num_layers) + with torch.device("meta"): + dummy_model = AutoModelForCausalLM.from_config( + config, trust_remote_code=model_config.trust_remote_code) + state_dict = dummy_model.state_dict() + + for hf_name in state_dict: + name, suffix = hf_name.rsplit(".", 1) + gguf_name = name_map.get_name(name) + gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name + return gguf_to_hf_name_map + + def _get_weights_iterator( + self, model_name_or_path: str, gguf_to_hf_name_map: dict[str, str] + ) -> Generator[tuple[str, torch.Tensor], None, None]: + return gguf_quant_weights_iterator(model_name_or_path, + gguf_to_hf_name_map) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model) + + def load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: + local_model_path = self._prepare_weights(model_config.model) + gguf_weights_map = self._get_gguf_weights_map(model_config) + model.load_weights( + self._get_weights_iterator(local_model_path, gguf_weights_map)) + + def load_model(self, vllm_config: VllmConfig, + model_config: ModelConfig) -> nn.Module: + device_config = vllm_config.device_config + local_model_path = self._prepare_weights(model_config.model) + gguf_weights_map = self._get_gguf_weights_map(model_config) + # we can only know if tie word embeddings after mapping weights + if "lm_head.weight" in get_gguf_extra_tensor_names( + local_model_path, gguf_weights_map): + model_config.hf_config.update({"tie_word_embeddings": True}) + + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = initialize_model(vllm_config=vllm_config) + self.load_weights(model, model_config) + + process_weights_after_loading(model, model_config, target_device) + return model diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py new file mode 100644 index 0000000..fad97ab --- /dev/null +++ b/vllm/model_executor/model_loader/neuron.py @@ -0,0 +1,476 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Utilities for selecting and loading Neuron models in transformers-neuronx +framework.""" +import ast +import copy +import importlib +import os +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import get_quantization_config +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SequenceOutput) + +TORCH_DTYPE_TO_NEURON_AMP = { + "auto": "f32", + "half": "f16", + "float16": "f16", + "bfloat16": "bf16", + "float": "f32", + "float32": "f32", + torch.float16: "f16", + torch.bfloat16: "bf16", + torch.float32: "f32", +} + +# Models supported by Neuron. +_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str, str]] = { + "LlamaForCausalLM": ("transformers_neuronx.llama.model", + "LlamaForSampling", "LlamaForCausalLM"), + "MistralForCausalLM": ("transformers_neuronx.mistral.model", + "MistralForSampling", "MistralForCausalLM") +} + + +class NeuronCausalLM(nn.Module): + + def __init__(self, + config: PretrainedConfig, + on_device_sampling_disabled: bool = False) -> None: + super().__init__() + self.config = config + self.logits_processor = LogitsProcessor(config.vocab_size, + logits_as_input=True) + + self.on_device_sampling_disabled = on_device_sampling_disabled + if self.on_device_sampling_disabled: + # Use default sampler + self.sampler = Sampler() + + # Lazy initialized + self.model: nn.Module + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + ) -> torch.Tensor: + logits = self.model(input_ids, + cache_ids=positions, + start_ids=input_block_ids) + return logits + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(None, hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + + if self.on_device_sampling_disabled: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + # On-device sampling outputs the token ids directly. + sampled_token_ids = logits.flatten() + next_tokens = [] + sample_idx = 0 + for seq_group in sampling_metadata.seq_groups: + samples = [] + for seq_id in seq_group.seq_ids: + token_id = sampled_token_ids[sample_idx].item() + samples.append( + SequenceOutput(parent_seq_id=seq_id, + output_token=token_id, + logprobs={token_id: Logprob(token_id)})) + sample_idx += 1 + next_tokens.append( + CompletionSequenceGroupOutput(samples=samples, + prompt_logprobs=None)) + + return SamplerOutput(outputs=next_tokens) + + def load_weights(self, model_name_or_path: str, **kwargs): + arch = _get_model_architecture(self.config) + neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = ( + _NEURON_SUPPORTED_MODELS[arch]) + neuronx_module = importlib.import_module(neuronx_module_path) + neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) + + self.model = neuronx_model_cls.from_pretrained(model_name_or_path, + **kwargs) + self.model.to_neuron() + + +class NeuronSpeculationCausalLM(nn.Module): + """A Neuron-optimized causal language model with speculative decoding.""" + + SPECULATION_TERMINATION_ID = -1 + + def __init__(self, speculation_model) -> None: + super().__init__() + self.model = speculation_model + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + ) -> torch.Tensor: + tokens, counts = self.model.speculative_iteration( + input_ids, positions, input_block_ids) + + # Mark the end of accepted speculative tokens for each sequence with the + # speculation termination id. + batch_size, steps = tokens.shape + mask = torch.arange(steps).expand(batch_size, -1) >= counts + tokens[mask] = self.SPECULATION_TERMINATION_ID + + return tokens + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[list[SamplerOutput]]: + batch_size, num_steps = logits.shape + seq_ids = [ + seq_id for sg in sampling_metadata.seq_groups + for seq_id in sg.seq_ids + ] + # Organize input tensors by step instead of by sequence. + accepted_token_ids_by_step = logits.transpose(0, 1) + accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() + + sampler_output_list = [] + for step_index in range(num_steps): + if all(token_id == self.SPECULATION_TERMINATION_ID + for token_id in accepted_token_ids_by_step[step_index]): + break + step_output_token_ids = [] + for sequence_index in range(batch_size): + token_id = accepted_token_ids_by_step[step_index][ + sequence_index] + step_output_token_ids.append( + CompletionSequenceGroupOutput(samples=[ + SequenceOutput(parent_seq_id=seq_ids[sequence_index], + output_token=token_id, + logprobs={token_id: Logprob(token_id)}) + ], + prompt_logprobs=None)) + sampler_output_list.append( + SamplerOutput(outputs=step_output_token_ids)) + return sampler_output_list + + +def _get_model_architecture(config: PretrainedConfig) -> str: + architectures = getattr(config, "architectures", []) + for arch in architectures: + if arch in _NEURON_SUPPORTED_MODELS: + return arch + raise ValueError( + f"Model architectures {architectures} are not supported on Neuron " + f"for now. Supported architectures: " + f"{list(_NEURON_SUPPORTED_MODELS.keys())}") + + +def _get_buckets(env: str, default_value: list[int]) -> list[int]: + env_value = os.getenv(env) + if env_value is None: + return default_value + buckets_remove_empty = filter( + lambda x: x is not None and len(x.strip()) > 0, env_value.split(",")) + buckets_int = map(int, buckets_remove_empty) + buckets_list = list(buckets_int) + return buckets_list + + +def _get_default_neuron_config(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig): + """Generate a neuron config based on vllm config args.""" + from transformers_neuronx.config import ContinuousBatchingConfig + from transformers_neuronx.constants import LAYOUT_BSH + + continuous_batching_config = ContinuousBatchingConfig( + batch_size_for_shared_caches=scheduler_config.max_num_seqs) + quant_config = dict( + dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + quantize_method="vector_dynamic") + neuron_quantization_config_builder = lambda quant: get_quantization_config( + quant).from_config(quant_config).get_quant_method(None, "") + # TODO: Add Paged attention config to the default neuron arguments. + default_neuron_args = dict( + collectives_layout=LAYOUT_BSH, + attention_layout=LAYOUT_BSH, + fuse_qkv=True, + quant=neuron_quantization_config_builder(model_config.quantization) + if model_config.quantization else None, + continuous_batching=continuous_batching_config, + weight_tiling=bool(model_config.quantization), + on_device_generation=_get_neuron_on_device_generation_config( + model_config)) + return default_neuron_args + + +def _get_default_neuron_config_for_speculation( + model_config: ModelConfig, parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig): + """Generate a neuron config for speculative decoding based on + vllm config args.""" + from transformers_neuronx.config import ContinuousBatchingConfig + from transformers_neuronx.constants import LAYOUT_BSH + + continuous_batching_config = ContinuousBatchingConfig( + batch_size_for_shared_caches=scheduler_config.max_num_seqs) + + default_neuron_args = dict(collectives_layout=LAYOUT_BSH, + attention_layout=LAYOUT_BSH, + fuse_qkv=True, + on_device_embedding=True, + continuous_batching=continuous_batching_config, + on_device_generation=copy.deepcopy( + model_config.neuron_sampling_params)) + return default_neuron_args + + +def _get_neuron_on_device_generation_config(model_config: ModelConfig): + if not _is_neuron_on_device_sampling_disabled(model_config): + return copy.deepcopy(model_config.neuron_sampling_params) + return None + + +def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool: + return not getattr(model_config, "neuron_sampling_params", None) + + +def _get_neuron_config_after_override(default_neuron_config, + overridden_neuron_config): + from transformers_neuronx.config import (ContinuousBatchingConfig, + GenerationConfig, + KVCacheQuantizationConfig, + NeuronConfig, QuantizationConfig, + SparseAttnConfig) + + sparse_attn = overridden_neuron_config.pop("sparse_attn", {}) + if sparse_attn: + overridden_neuron_config["sparse_attn"] = SparseAttnConfig( + **sparse_attn) + + kv_cache_quant = overridden_neuron_config.pop("kv_cache_quant", {}) + if kv_cache_quant: + overridden_neuron_config["kv_cache_quant"] = KVCacheQuantizationConfig( + **kv_cache_quant) + + continuous_batching = overridden_neuron_config.pop("continuous_batching", + {}) + if continuous_batching: + overridden_neuron_config[ + "continuous_batching"] = ContinuousBatchingConfig( + **continuous_batching) + + quant = overridden_neuron_config.pop("quant", {}) + if quant: + overridden_neuron_config["quant"] = QuantizationConfig(**quant) + + on_device_generation = overridden_neuron_config.pop( + "on_device_generation", {}) + if on_device_generation: + overridden_neuron_config["on_device_generation"] = GenerationConfig( + **on_device_generation) + default_neuron_config.update(overridden_neuron_config) + return NeuronConfig(**default_neuron_config) + + +def get_neuron_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig) -> nn.Module: + """Initializes a neuron-optimized model for inference.""" + # Create a model instance. + model = NeuronCausalLM( + model_config.hf_config, + _is_neuron_on_device_sampling_disabled(model_config)) + + default_neuron_config_args = _get_default_neuron_config( + model_config, parallel_config, scheduler_config) + + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) + + context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", + [scheduler_config.max_model_len]) + n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", + [scheduler_config.max_model_len]) + + model.load_weights(model_config.model, + tp_degree=parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + neuron_config=neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) + + return model.eval() + + +def get_neuron_speculation_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + speculation_config: SpeculativeConfig): + """Initializes a neuron-optimized speculation model for inference. + + This method is only applicable for speculation with a standalone draft model + """ + from transformers_neuronx.fused_speculation import FusedSpeculativeDecoder + + # For Eagle SD, we need to pass in additional parameters in neuron config. + is_eagle = getattr(speculation_config.draft_model_config.hf_config, + "is_eagle", False) + + # Create target model instance. + target_model = NeuronCausalLM(model_config.hf_config) + + default_neuron_config_args = _get_default_neuron_config_for_speculation( + model_config, parallel_config, scheduler_config) + if is_eagle: + default_neuron_config_args['is_eagle_target'] = True + + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) + + context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", + [scheduler_config.max_model_len]) + n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", + [scheduler_config.max_model_len]) + + target_model.load_weights( + model_config.model, + tp_degree=parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + neuron_config=neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) + + target_model.eval() + + # Create draft model instance. + draft_model = NeuronCausalLM( + speculation_config.draft_model_config.hf_config) + + default_draft_neuron_config_args = ( + _get_default_neuron_config_for_speculation( + speculation_config.draft_model_config, parallel_config, + scheduler_config)) + if is_eagle: + default_draft_neuron_config_args['is_eagle_draft'] = True + default_draft_neuron_config_args['has_pre_attention_norm'] = False + + draft_neuron_config = _get_neuron_config_after_override( + default_draft_neuron_config_args, + speculation_config.draft_model_config.override_neuron_config) + + draft_model.load_weights(speculation_config.draft_model_config.model, + tp_degree=speculation_config. + draft_parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[ + speculation_config.draft_model_config.dtype], + neuron_config=draft_neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) + + draft_model.eval() + + num_speculative_tokens = speculation_config.num_speculative_tokens + # Create speculation model instance. + speculation_model = FusedSpeculativeDecoder(draft_model.model, + target_model.model, + num_speculative_tokens) + speculation_model.to_neuron() + + return NeuronSpeculationCausalLM(speculation_model) + + +def get_neuron_eagle_speculation_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + speculation_config: SpeculativeConfig): + """Initializes a neuron-optimized EAGLE speculation model for inference.""" + from transformers_neuronx.eagle_speculation import EagleSpeculativeDecoder + + # Create target model instance. + target_model = NeuronCausalLM(model_config.hf_config) + + default_neuron_config_args = _get_default_neuron_config_for_speculation( + model_config, parallel_config, scheduler_config) + default_neuron_config_args['is_eagle_target'] = True + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) + + context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", + [scheduler_config.max_model_len]) + n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", + [scheduler_config.max_model_len]) + + target_model.load_weights( + model_config.model, + tp_degree=parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + neuron_config=neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) + + target_model.eval() + + # Create draft model instance. + draft_model = NeuronCausalLM( + speculation_config.draft_model_config.hf_config) + + default_draft_neuron_config_args = ( + _get_default_neuron_config_for_speculation( + speculation_config.draft_model_config, parallel_config, + scheduler_config)) + default_draft_neuron_config_args['is_eagle_draft'] = True + default_draft_neuron_config_args['has_pre_attention_norm'] = False + draft_neuron_config = _get_neuron_config_after_override( + default_draft_neuron_config_args, + speculation_config.draft_model_config.override_neuron_config) + + draft_model.load_weights(speculation_config.draft_model_config.model, + tp_degree=speculation_config. + draft_parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[ + speculation_config.draft_model_config.dtype], + neuron_config=draft_neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) + + draft_model.eval() + + token_tree: dict[int, list[int]] = ast.literal_eval( + speculation_config.speculative_token_tree) + + speculation_model = EagleSpeculativeDecoder(draft_model.model, + target_model.model, + token_tree=token_tree) + speculation_model.to_neuron() + + return NeuronSpeculationCausalLM(speculation_model) diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py new file mode 100644 index 0000000..f450961 --- /dev/null +++ b/vllm/model_executor/model_loader/neuronx_distributed.py @@ -0,0 +1,685 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Utilities for selecting and loading Neuron models in +neuronx-distributed-inference framework.""" +# Disabling yapf because yapf and isort have conflicts for the below imports +# yapf: disable +import copy +import hashlib +import importlib +import multiprocessing +import os +import shutil +from typing import Optional + +import torch +import torch.nn as nn +from neuronx_distributed_inference.models.config import ( + FusedSpecNeuronConfig, OnDeviceSamplingConfig) +from neuronx_distributed_inference.models.mllama.utils import ( + create_vision_mask) +from neuronx_distributed_inference.modules.lora_serving import ( + LoraServingConfig) +from neuronx_distributed_inference.utils.hf_adapter import ( + load_pretrained_config) +from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig + +from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig) +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SequenceOutput) + +# yapf: enable +logger = init_logger(__name__) + +TORCH_DTYPE_TO_NEURON_AMP = { + "auto": "float32", + "half": "float16", + "float16": "float16", + "bfloat16": "bfloat16", + "float": "float32", + "float32": "float32", + torch.float16: "float16", + torch.bfloat16: "bfloat16", + torch.float32: "float32", +} + +# Models supported by Neuronx distributed for inference. +_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str]] = { + "LlamaForCausalLM": + ("neuronx_distributed_inference.models.llama.modeling_llama", + "NeuronLlamaForCausalLM"), + "MistralForCausalLM": + ("neuronx_distributed_inference.models.llama.modeling_llama", + "NeuronLlamaForCausalLM"), + "DbrxForCausalLM": + ("neuronx_distributed_inference.models.dbrx.modeling_dbrx", + "NeuronDbrxForCausalLM"), + "MixtralForCausalLM": + ("neuronx_distributed_inference.models.mixtral.modeling_mixtral", + "NeuronMixtralForCausalLM"), + "MllamaForConditionalGeneration": + ("neuronx_distributed_inference.models.mllama.modeling_mllama", + "NeuronMllamaForCausalLM"), +} + + +class NeuronCausalLM(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + ) -> None: + super().__init__() + self.config = config + self.logits_processor = LogitsProcessor(config.vocab_size, + logits_as_input=True) + self.sampler = Sampler() + + # Lazy initialized + self.model: nn.Module + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + sampling_params: torch.Tensor, + prev_hidden: Optional[torch.Tensor] = None, + adapter_ids: Optional[torch.Tensor] = None) -> torch.Tensor: + # sort block ids sequentially for perf/neuron support reasons + sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids) + input_ids = torch.index_select(input_ids, 0, sorted_indices) + positions = torch.index_select(positions, 0, sorted_indices) + sampling_params = torch.index_select(sampling_params, 0, + sorted_indices) + output = self.model(input_ids, + attention_mask=None, + position_ids=positions, + seq_ids=sorted_input_block_ids, + sampling_params=sampling_params, + prev_hidden=prev_hidden, + adapter_ids=adapter_ids) + # on-device sampling + if self.config.neuron_config.on_device_sampling_config: + output = output.hidden_states + else: + output = output.logits[:, -1, :] + + restored_indices = torch.argsort(sorted_indices) + if input_block_ids.shape[0] != 1: + output = torch.index_select(output, 0, restored_indices) + + return output + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(None, hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + # on-device sampling + if self.config.neuron_config.on_device_sampling_config: + batch_size = logits.shape + seq_ids = [ + seq_id for sg in sampling_metadata.seq_groups + for seq_id in sg.seq_ids + ] + assert len(seq_ids) == list(batch_size)[0], "batch size mismatch" + # Organize input tensors by step instead of by sequence. + accepted_token_ids_by_step = logits.flatten() + accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() + + step_output_token_ids = [] + for i, seq_id in enumerate(seq_ids): + token_id = accepted_token_ids_by_step[i] + step_output_token_ids.append( + CompletionSequenceGroupOutput(samples=[ + SequenceOutput(parent_seq_id=seq_id, + output_token=token_id, + logprobs={token_id: Logprob(token_id)}) + ], + prompt_logprobs=None)) + return SamplerOutput(outputs=step_output_token_ids) + else: + return self.sampler(logits, sampling_metadata) + + def load_weights(self, model_name_or_path: str, **kwargs): + arch = _get_model_architecture(self.config) + neuronx_module_path, neuronx_model_cls_name = ( + _NEURON_SUPPORTED_MODELS[arch]) + neuronx_module = importlib.import_module(neuronx_module_path) + neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) + neuron_config = neuronx_model_cls.get_neuron_config_cls()( + **kwargs['neuron_config']) + self.config.neuron_config = neuron_config + config = neuronx_model_cls.get_config_cls()( + neuron_config, + load_config=load_pretrained_config(model_name_or_path)) + hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'), + usedforsecurity=False).hexdigest() + if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: + compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") + elif os.path.exists(model_name_or_path): + compiled_model_path = os.path.join(model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) + shutil.rmtree(compiled_model_path, ignore_errors=True) + else: + compiled_model_path = os.path.join("local-models", + model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) + shutil.rmtree(compiled_model_path, ignore_errors=True) + try: + self.model = neuronx_model_cls(compiled_model_path) + override_neuron_config = kwargs["override_neuron_config"] + for k, v in override_neuron_config.items(): + setattr(self.model.config.neuron_config, k, v) + self.model.load(compiled_model_path) + return + except (FileNotFoundError, ValueError) as e: + logger.warning("Exception: %s", e) + logger.warning("Failed to load the model from %s, Recompiling...", + compiled_model_path) + if not os.path.exists(model_name_or_path): + hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) + saved_path = os.path.join("local-models", model_name_or_path) + hf_model.save_pretrained(saved_path) + model_name_or_path = saved_path + self.model = neuronx_model_cls(model_name_or_path, config) + self.model.compile(compiled_model_path) + self.model.load(compiled_model_path) + + +class NeuronMllamaForCausalLM(nn.Module): + + def __init__(self, + config: PretrainedConfig, + on_device_sampling_disabled: bool = False) -> None: + super().__init__() + # has_image is the only multimodal input that is used in + # token-generation + # This is a cache (on CPU) that saves has_image data per sequence id + # The number of entries in this cache is <= Batch-Size + self.has_image_cache: dict[int, torch.Tensor] = {} + self.config = config + self.logits_processor = LogitsProcessor( + config.get_text_config().vocab_size, logits_as_input=True) + + self.on_device_sampling_disabled = on_device_sampling_disabled + if self.on_device_sampling_disabled: + # Use default sampler + self.sampler = Sampler() + + # Lazy initialized + self.model: nn.Module + self.is_reorder_needed: bool = True + + def read_from_has_image_cache(self, seq_ids: torch.Tensor): + has_image_list = [] + for index in range(len(seq_ids)): + seq_id = seq_ids[index].item() + if seq_id in self.has_image_cache: + has_image_list.append(self.has_image_cache[seq_id]) + else: + has_image_list.append(torch.tensor([0])) + return torch.tensor(has_image_list) + + def write_to_has_image_cache(self, seq_ids: torch.Tensor, + has_image: torch.Tensor): + for index in range(len(seq_ids)): + seq_id = seq_ids[index].item() + if index < len(has_image): + self.has_image_cache[seq_id] = has_image[index] + else: + self.has_image_cache[seq_id] = torch.zeros(1) + + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + seq_ids: torch.Tensor, pixel_values: torch.Tensor, + aspect_ratios: torch.Tensor, num_chunks: torch.Tensor, + has_image: torch.Tensor, sampling_params) -> torch.Tensor: + + # We update the has_image cache during prefill + # and read the has_image cache during decode + if input_ids.shape[-1] > 1: # prefill + self.write_to_has_image_cache(seq_ids, has_image) + else: + has_image = self.read_from_has_image_cache(seq_ids) + bs = input_ids.shape[0] + num_chunks = torch.zeros((bs, 1)) + aspect_ratios = torch.zeros((bs, 1, 2)) + + input_block_ids = seq_ids + origin_input_block_ids = seq_ids + if self.is_reorder_needed: + # sort block ids sequentially for perf/neuron support reasons + input_block_ids, sorted_indices = torch.sort(input_block_ids) + input_ids = torch.index_select(input_ids, 0, sorted_indices) + positions = torch.index_select(positions, 0, sorted_indices) + sampling_params = torch.index_select(sampling_params, 0, + sorted_indices) + pixel_values = torch.index_select(pixel_values, 0, sorted_indices) + aspect_ratios = torch.index_select(aspect_ratios, 0, + sorted_indices) + num_chunks = torch.index_select(num_chunks, 0, sorted_indices) + has_image = torch.index_select(has_image, 0, sorted_indices) + + self.vision_mask = create_vision_mask(input_ids, self.vision_token_id) + output = self.model( + input_ids.to(torch.int32), + attention_mask=None, + position_ids=positions.to(torch.int32), + seq_ids=seq_ids.flatten().to(torch.int32), + pixel_values=pixel_values.to( + self.config.vision_config.torch_dtype), + aspect_ratios=aspect_ratios.to(torch.int32), + vision_mask=self.vision_mask.to(torch.int32), + sampling_params=sampling_params, + num_chunks=num_chunks.to(torch.int32), + has_image=has_image.to(torch.int32), + ) + if self.config.neuron_config.on_device_sampling_config: + output = output.hidden_states + else: + output = output.logits[:, -1, :] + + if self.is_reorder_needed and origin_input_block_ids.shape[0] != 1: + restored_indices = torch.argsort(sorted_indices) + output = torch.index_select(output, 0, restored_indices) + return output + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(None, hidden_states, sampling_metadata) + return logits + + def sample(self, hidden_states, sampling_metadata): + if not self.on_device_sampling_disabled: + with torch.profiler.record_function("sample"): + hidden_states = hidden_states.flatten() + res = [] + sample_idx = 0 + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + samples = [] + for seq_id in seq_ids: + token_id = hidden_states[sample_idx].item() + samples.append( + SequenceOutput( + parent_seq_id=seq_id, + output_token=token_id, + logprobs={token_id: Logprob(token_id)})) + sample_idx += 1 + res.append( + CompletionSequenceGroupOutput(samples=samples, + prompt_logprobs=None)) + next_tokens = SamplerOutput(outputs=res) + else: + next_tokens = self.sampler(None, hidden_states, sampling_metadata) + return next_tokens + + def load_weights(self, model_name_or_path: str, **kwargs): + arch = _get_model_architecture(self.config) + neuronx_module_path, neuronx_model_cls_name = ( + _NEURON_SUPPORTED_MODELS[arch]) + neuronx_module = importlib.import_module(neuronx_module_path) + neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) + neuron_config = neuronx_model_cls.get_neuron_config_cls()( + **kwargs['neuron_config']) + self.config.neuron_config = neuron_config + logger.info("neuron_config buckets: %s", + self.config.neuron_config.buckets) + config = neuronx_model_cls.get_config_cls()( + neuron_config, + load_config=load_pretrained_config(model_name_or_path)) + hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'), + usedforsecurity=False).hexdigest() + if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: + compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") + elif os.path.exists(model_name_or_path): + compiled_model_path = os.path.join(model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) + else: + compiled_model_path = os.path.join("local-models", + model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) + try: + self.model = neuronx_model_cls(compiled_model_path) + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + self.vision_token_id = tokenizer( + "<|image|>", add_special_tokens=False).input_ids[0] + self.model.load(compiled_model_path) + return + except (FileNotFoundError, ValueError): + logger.warning("Failed to load the model from %s, Recompiling...", + compiled_model_path) + if not os.path.exists(model_name_or_path): + hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) + saved_path = os.path.join("local-models", model_name_or_path) + hf_model.save_pretrained(saved_path) + model_name_or_path = saved_path + self.model = neuronx_model_cls(model_name_or_path, config) + + logger.info("\nCompiling and saving model to %s", model_name_or_path) + + p = multiprocessing.Process(target=compile_model, + args=(self, compiled_model_path)) + p.start() + p.join() + + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + tokenizer.save_pretrained(compiled_model_path) + logger.info("Successfully compiled and saved the model in %s", + compiled_model_path) + + # Read "<|image|>" token_id from the tokenizer + self.vision_token_id = tokenizer("<|image|>", + add_special_tokens=False).input_ids[0] + logger.info("\nLoading model from compiled checkpoint...") + self.model.load(compiled_model_path) + + +def compile_model(neuron_model, traced_model_path): + neuron_model.model.compile(traced_model_path) + + +class NeuronSpeculationCausalLM(nn.Module): + """A Neuron-optimized causal language model with speculative decoding.""" + + def __init__( + self, + config: PretrainedConfig, + ) -> None: + super().__init__() + self.config = config + self.logits_processor = LogitsProcessor(config.vocab_size, + logits_as_input=True) + # Lazy initialized + self.model: nn.Module + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + sampling_params: torch.Tensor, + ) -> torch.Tensor: + # sort block ids sequentially for perf/neuron support reasons + sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids) + input_ids = torch.index_select(input_ids, 0, sorted_indices) + positions = torch.index_select(positions, 0, sorted_indices) + sampling_params = torch.index_select(sampling_params, 0, + sorted_indices) + + output = self.model(input_ids, + attention_mask=None, + position_ids=positions, + seq_ids=sorted_input_block_ids, + sampling_params=sampling_params) + restored_indices = torch.argsort(sorted_indices) + + # CTX encoding + if (positions[:, 0]).sum().item() == 0: + output = output.fused_outputs[0][:, 0:1] + if input_block_ids.shape[0] != 1: + output = torch.index_select(output, 0, restored_indices) + return output + + # Fused Spec (Generation) + accepted_tokens_with_padding = output.fused_outputs[0] + next_pos_ids = output.fused_outputs[-1] + generated_token_counts = next_pos_ids - positions + + assert torch.any(generated_token_counts == 0).item() is False, \ + "NxDI model generated no output for one or more sequences." + + batch_size, steps = accepted_tokens_with_padding.shape + mask = torch.arange(steps).expand(batch_size, + -1) >= generated_token_counts + accepted_tokens_with_padding[mask] = -1 + + if input_block_ids.shape[0] != 1: + accepted_tokens_with_padding = torch.index_select( + accepted_tokens_with_padding, 0, restored_indices) + + return accepted_tokens_with_padding + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[list[SamplerOutput]]: + batch_size, num_steps = logits.shape + seq_ids = [ + seq_id for sg in sampling_metadata.seq_groups + for seq_id in sg.seq_ids + ] + # Organize input tensors by step instead of by sequence. + accepted_token_ids_by_step = logits.transpose(0, 1) + accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() + + sampler_output_list = [] + for step_index in range(num_steps): + if all(token_id == -1 + for token_id in accepted_token_ids_by_step[step_index]): + break + step_output_token_ids = [] + for sequence_index in range(batch_size): + token_id = accepted_token_ids_by_step[step_index][ + sequence_index] + step_output_token_ids.append( + CompletionSequenceGroupOutput(samples=[ + SequenceOutput(parent_seq_id=seq_ids[sequence_index], + output_token=token_id, + logprobs={token_id: Logprob(token_id)}) + ], + prompt_logprobs=None)) + sampler_output_list.append( + SamplerOutput(outputs=step_output_token_ids)) + return sampler_output_list + + def load_weights(self, model_name_or_path: str, + draft_model_name_or_path: str, **kwargs): + arch = _get_model_architecture(self.config) + neuronx_module_path, neuronx_model_cls_name = ( + _NEURON_SUPPORTED_MODELS[arch]) + neuronx_module = importlib.import_module(neuronx_module_path) + neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) + neuron_config = neuronx_model_cls.get_neuron_config_cls()( + **kwargs['neuron_config']) + config = neuronx_model_cls.get_config_cls()( + neuron_config, + load_config=load_pretrained_config(model_name_or_path)) + + draft_neuron_config = copy.deepcopy(config.neuron_config) + if not config.neuron_config.enable_eagle_speculation: + draft_neuron_config.speculation_length = 0 + draft_neuron_config.trace_tokengen_model = True + draft_neuron_config.enable_fused_speculation = False + if getattr(config.neuron_config, "draft_model_modules_to_not_convert", + None): + draft_neuron_config.modules_to_not_convert = ( + draft_neuron_config.draft_model_modules_to_not_convert) + if config.neuron_config.enable_eagle_speculation: + draft_neuron_config.is_eagle_draft = True + draft_neuron_config.sequence_parallel_enabled = False + draft_config = neuronx_model_cls.get_config_cls()( + draft_neuron_config, + load_config=load_pretrained_config(draft_model_name_or_path)) + fused_spec_config = (FusedSpecNeuronConfig( + neuronx_model_cls._model_cls, + draft_config=draft_config, + draft_model_path=draft_model_name_or_path)) + config.fused_spec_config = fused_spec_config + self.config.neuron_config = neuron_config + + hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'), + usedforsecurity=False).hexdigest() + if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: + compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") + elif os.path.exists(model_name_or_path): + compiled_model_path = os.path.join(model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) + shutil.rmtree(compiled_model_path, ignore_errors=True) + else: + compiled_model_path = os.path.join("local-models", + model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) + shutil.rmtree(compiled_model_path, ignore_errors=True) + try: + self.model = neuronx_model_cls(compiled_model_path) + override_neuron_config = kwargs["override_neuron_config"] + for k, v in override_neuron_config.items(): + setattr(self.model.config.neuron_config, k, v) + self.model.load(compiled_model_path) + return + except (FileNotFoundError, ValueError) as e: + logger.warning("Exception: %s", e) + logger.warning("Failed to load the model from %s Recompiling...", + compiled_model_path) + if not os.path.exists(model_name_or_path): + hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) + saved_path = os.path.join("local-models", model_name_or_path) + hf_model.save_pretrained(saved_path) + model_name_or_path = saved_path + if not os.path.exists(draft_model_name_or_path): + if draft_model_name_or_path != model_name_or_path: + hf_model = AutoModelForCausalLM.from_pretrained( + draft_model_name_or_path) + saved_path = os.path.join("local-models", + draft_model_name_or_path) + hf_model.save_pretrained(saved_path) + draft_model_name_or_path = saved_path + else: + draft_model_name_or_path = model_name_or_path + config.fused_spec_config.draft_model_path = draft_model_name_or_path + self.model = neuronx_model_cls(model_name_or_path, config) + self.model.compile(compiled_model_path) + self.model.load(compiled_model_path) + + +def _get_model_architecture(config: PretrainedConfig) -> str: + architectures = getattr(config, "architectures", []) + for arch in architectures: + if arch in _NEURON_SUPPORTED_MODELS: + return arch + raise ValueError( + f"Model architectures {architectures} are not supported on Neuron " + f"for now. Supported architectures: " + f"{list(_NEURON_SUPPORTED_MODELS.keys())}") + + +def _get_default_neuron_config(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + lora_serving_config: LoraServingConfig): + """Generate a neuron config based on vllm config args.""" + on_device_sampling_config = OnDeviceSamplingConfig(dynamic=True, + deterministic=False) + batch_size = scheduler_config.max_num_seqs + + neuron_config = dict( + tp_degree=parallel_config.tensor_parallel_size, + ctx_batch_size=1, + batch_size=batch_size, + max_context_length=scheduler_config.max_model_len, + seq_len=scheduler_config.max_model_len, + enable_bucketing=True, + is_continuous_batching=True, + quantized=False, + torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + padding_side="right", + on_device_sampling_config=on_device_sampling_config, + sequence_parallel_enabled=True, + lora_serving_config=lora_serving_config) + return neuron_config + + +def _get_default_speculation_config(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + speculation_config: SpeculativeConfig): + """Generate a neuron config for speculative decoding based on vllm config + args.""" + neuron_config = dict( + tp_degree=parallel_config.tensor_parallel_size, + ctx_batch_size=1, + batch_size=scheduler_config.max_num_seqs, + max_context_length=scheduler_config.max_model_len, + seq_len=scheduler_config.max_model_len, + speculation_length=speculation_config.num_speculative_tokens, + trace_tokengen_model=False, + enable_fused_speculation=True, + enable_bucketing=True, + is_continuous_batching=True, + quantized=False, + torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + on_device_sampling_config=dict( + top_k=1, + do_sample=False, + )) + return neuron_config + + +def _get_neuron_config_after_override(default_neuron_config, + overridden_neuron_config): + """Update default neuron config values with override args""" + overridden_neuron_config = overridden_neuron_config or {} + default_neuron_config.update(overridden_neuron_config) + return default_neuron_config + + +def get_neuron_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + lora_serving_config: LoraServingConfig) -> nn.Module: + """Initializes a neuron-optimized model for inference.""" + model_arch = _get_model_architecture(model_config.hf_config) + if model_arch == "MllamaForConditionalGeneration": + model = NeuronMllamaForCausalLM(model_config.hf_config) + else: + model = NeuronCausalLM(model_config.hf_config) + default_neuron_config_args = _get_default_neuron_config( + model_config, parallel_config, scheduler_config, lora_serving_config) + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) + + override_neuron_config = model_config.override_neuron_config + model.load_weights(model_config.model, + neuron_config=neuron_config, + override_neuron_config=override_neuron_config) + return model.eval() + + +def get_neuron_speculation_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + speculation_config: SpeculativeConfig): + """Initializes a neuron-optimized speculation model for inference. + + This model handles speculation using both a draft model and an EAGLE draft. + """ + model = NeuronSpeculationCausalLM(model_config.hf_config) + default_neuron_config_args = _get_default_speculation_config( + model_config, parallel_config, scheduler_config, speculation_config) + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) + + override_neuron_config = model_config.override_neuron_config + model.load_weights(model_config.model, + speculation_config.draft_model_config.model, + neuron_config=neuron_config, + override_neuron_config=override_neuron_config) + return model.eval() diff --git a/vllm/model_executor/model_loader/runai_streamer_loader.py b/vllm/model_executor/model_loader/runai_streamer_loader.py new file mode 100644 index 0000000..83e0f38 --- /dev/null +++ b/vllm/model_executor/model_loader/runai_streamer_loader.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: SIM117 +import glob +import os +from collections.abc import Generator +from typing import Optional + +import torch +from torch import nn +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME + +from vllm.config import LoadConfig, ModelConfig +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.weight_utils import ( + download_safetensors_index_file_from_hf, download_weights_from_hf, + runai_safetensors_weights_iterator) +from vllm.transformers_utils.s3_utils import glob as s3_glob +from vllm.transformers_utils.utils import is_s3 + + +class RunaiModelStreamerLoader(BaseModelLoader): + """ + Model loader that can load safetensors + files from local FS or S3 bucket. + """ + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + extra_config = load_config.model_loader_extra_config + + if ("concurrency" in extra_config + and isinstance(extra_config.get("concurrency"), int)): + os.environ["RUNAI_STREAMER_CONCURRENCY"] = str( + extra_config.get("concurrency")) + + if ("memory_limit" in extra_config + and isinstance(extra_config.get("memory_limit"), int)): + os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str( + extra_config.get("memory_limit")) + + runai_streamer_s3_endpoint = os.getenv( + 'RUNAI_STREAMER_S3_ENDPOINT') + aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL') + if (runai_streamer_s3_endpoint is None + and aws_endpoint_url is not None): + os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]) -> list[str]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + + is_s3_path = is_s3(model_name_or_path) + is_local = os.path.isdir(model_name_or_path) + safetensors_pattern = "*.safetensors" + index_file = SAFE_WEIGHTS_INDEX_NAME + + hf_folder = (model_name_or_path if + (is_local or is_s3_path) else download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + [safetensors_pattern], + revision, + ignore_patterns=self.load_config.ignore_patterns, + )) + if is_s3_path: + hf_weights_files = s3_glob(path=hf_folder, + allow_pattern=[safetensors_pattern]) + else: + hf_weights_files = glob.glob( + os.path.join(hf_folder, safetensors_pattern)) + + if not is_local and not is_s3_path: + download_safetensors_index_file_from_hf( + model_name_or_path, index_file, self.load_config.download_dir, + revision) + + if not hf_weights_files: + raise RuntimeError( + f"Cannot find any safetensors model weights with " + f"`{model_name_or_path}`") + + return hf_weights_files + + def _get_weights_iterator( + self, model_or_path: str, + revision: str) -> Generator[tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_weights_files = self._prepare_weights(model_or_path, revision) + return runai_safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + ) + + def download_model(self, model_config: ModelConfig) -> None: + """Download model if necessary""" + self._prepare_weights(model_config.model, model_config.revision) + + def load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: + """Load weights into a model.""" + model_weights = model_config.model + if hasattr(model_config, "model_weights"): + model_weights = model_config.model_weights + model.load_weights( + self._get_weights_iterator(model_weights, model_config.revision)) diff --git a/vllm/model_executor/model_loader/sharded_state_loader.py b/vllm/model_executor/model_loader/sharded_state_loader.py new file mode 100644 index 0000000..2fd9cfb --- /dev/null +++ b/vllm/model_executor/model_loader/sharded_state_loader.py @@ -0,0 +1,201 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import collections +import glob +import os +from collections.abc import Generator +from typing import Any, Optional + +import torch +from torch import nn + +from vllm.config import LoadConfig, ModelConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.weight_utils import ( + download_weights_from_hf, runai_safetensors_weights_iterator) +from vllm.transformers_utils.s3_utils import glob as s3_glob +from vllm.transformers_utils.utils import is_s3 + +logger = init_logger(__name__) + + +class ShardedStateLoader(BaseModelLoader): + """ + Model loader that directly loads each worker's model state dict, which + enables a fast load path for large tensor-parallel models where each worker + only needs to read its own shard rather than the entire checkpoint. See + `examples/offline_inference/save_sharded_state.py` for creating a sharded + checkpoint. + """ + + DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors" + + def __init__(self, + load_config: LoadConfig, + runai_model_streamer: bool = False): + super().__init__(load_config) + + self.runai_model_streamer = runai_model_streamer + extra_config = ({} if load_config.model_loader_extra_config is None + else load_config.model_loader_extra_config.copy()) + self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) + if extra_config: + raise ValueError(f"Unexpected extra config keys for load format " + f"{load_config.load_format}: " + f"{load_config.model_loader_extra_config.keys()}") + + @staticmethod + def _filter_subtensors( + tensors: dict[str, torch.Tensor], ) -> dict[str, torch.Tensor]: + """ + Filter out all tensors that share the same memory or a subset of the + memory of another tensor. + """ + same_storage_groups: dict[Any, list[tuple[str, torch.Tensor]]] = ( + collections.defaultdict(list)) + for key, tensor in tensors.items(): + if tensor.numel(): + ptr = tensor.untyped_storage().data_ptr() + same_storage_groups[tensor.device, ptr].append((key, tensor)) + + def get_end_ptr(tensor: torch.Tensor) -> int: + return tensor.view(-1)[-1].data_ptr() + tensor.element_size() + + result: dict[str, torch.Tensor] = {} + for group in same_storage_groups.values(): + for k, t in group: + a, b = t.data_ptr(), get_end_ptr(t) + for k2, t2 in group: + if not t2.is_contiguous(): + continue + a2, b2 = t2.data_ptr(), get_end_ptr(t2) + if a < a2 or b2 < b: + continue + if a2 < a or b < b2 or not t.is_contiguous(): + break # t2 covers strictly more memory than t. + if k2 < k: + # Same tensors, keep the one with the smaller key. + break + else: + result[k] = t + return result + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]): + if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path): + return model_name_or_path + else: + allow_patterns = ["*.safetensors"] + return download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, model_config.revision) + + def load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: + from vllm.distributed import get_tensor_model_parallel_rank + + model_weights = model_config.model + if hasattr(model_config, "model_weights"): + model_weights = model_config.model_weights + local_model_path = model_weights + + rank = get_tensor_model_parallel_rank() + pattern = os.path.join( + local_model_path, + self.pattern.format(rank=rank, part="*"), + ) + + filepaths = [] + if is_s3(local_model_path): + file_pattern = f"*{self.pattern.format(rank=rank, part=' * ')}" + filepaths = s3_glob(path=local_model_path, + allow_pattern=[file_pattern]) + else: + filepaths = glob.glob(pattern) + if not filepaths: + # TODO: support un-sharded checkpoints too + raise ValueError( + f"Could not find checkpoint files '{pattern}', only " + f"pre-sharded checkpoints are currently supported!") + state_dict = self._filter_subtensors(model.state_dict()) + for key, tensor in self.iterate_over_files(filepaths): + # If loading with LoRA enabled, additional padding may + # be added to certain parameters. We only load into a + # narrowed view of the parameter data. + param_data = state_dict[key].data + param_shape = state_dict[key].shape + for dim, size in enumerate(tensor.shape): + if size < param_shape[dim]: + param_data = param_data.narrow(dim, 0, size) + if tensor.shape != param_shape: + logger.warning( + "loading tensor of shape %s into " + "parameter '%s' of shape %s", + tensor.shape, + key, + param_shape, + ) + param_data.copy_(tensor) + state_dict.pop(key) + if state_dict: + raise ValueError( + f"Missing keys {tuple(state_dict)} in loaded state!") + + def iterate_over_files( + self, paths) -> Generator[tuple[str, torch.Tensor], None, None]: + if self.runai_model_streamer: + yield from runai_safetensors_weights_iterator(paths, True) + else: + from safetensors.torch import safe_open + for path in paths: + with safe_open(path, framework="pt") as f: + for key in f.keys(): # noqa: SIM118 + tensor = f.get_tensor(key) + yield key, tensor + + @staticmethod + def save_model( + model: torch.nn.Module, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + from safetensors.torch import save_file + + from vllm.distributed import get_tensor_model_parallel_rank + + if pattern is None: + pattern = ShardedStateLoader.DEFAULT_PATTERN + rank = get_tensor_model_parallel_rank() + part_idx = 0 + total_size = 0 + state_dict = ShardedStateLoader._filter_subtensors(model.state_dict()) + state_dict_part: dict[str, torch.Tensor] = {} + for key, tensor in state_dict.items(): + param_size = tensor.nelement() * tensor.element_size() + if max_size is not None and total_size + param_size > max_size: + filename = pattern.format(rank=rank, part=part_idx) + save_file( + state_dict_part, + os.path.join(path, filename), + ) + part_idx += 1 + total_size = 0 + state_dict_part = {} + state_dict_part[key] = tensor + total_size += param_size + if len(state_dict_part) > 0: + filename = pattern.format(rank=rank, part=part_idx) + save_file( + state_dict_part, + os.path.join(path, filename), + ) diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py new file mode 100644 index 0000000..1c14d55 --- /dev/null +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -0,0 +1,602 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import contextlib +import contextvars +import dataclasses +import io +import json +import os +import threading +import time +from collections.abc import Generator +from dataclasses import dataclass +from functools import partial +from typing import TYPE_CHECKING, Any, BinaryIO, Optional, Union + +import regex as re +import torch +from torch import nn +from torch.utils._python_dispatch import TorchDispatchMode +from transformers import PretrainedConfig + +import vllm.envs as envs +from vllm.config import (ModelConfig, ParallelConfig, VllmConfig, + set_current_vllm_config) +from vllm.logger import init_logger +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.utils import FlexibleArgumentParser, PlaceholderModule + +if TYPE_CHECKING: + from vllm.engine.arg_utils import EngineArgs + +try: + from tensorizer import (DecryptionParams, EncryptionParams, + TensorDeserializer, TensorSerializer) + from tensorizer.stream_io import open_stream + from tensorizer.utils import (convert_bytes, get_mem_usage, + no_init_or_tensor) + + _read_stream, _write_stream = (partial( + open_stream, + mode=mode, + ) for mode in ("rb", "wb+")) +except ImportError: + tensorizer = PlaceholderModule("tensorizer") + DecryptionParams = tensorizer.placeholder_attr("DecryptionParams") + EncryptionParams = tensorizer.placeholder_attr("EncryptionParams") + TensorDeserializer = tensorizer.placeholder_attr("TensorDeserializer") + TensorSerializer = tensorizer.placeholder_attr("TensorSerializer") + open_stream = tensorizer.placeholder_attr("stream_io.open_stream") + convert_bytes = tensorizer.placeholder_attr("utils.convert_bytes") + get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage") + no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor") + + _read_stream = tensorizer.placeholder_attr("_read_stream") + _write_stream = tensorizer.placeholder_attr("_write_stream") + +__all__ = [ + 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer', + 'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage', + 'no_init_or_tensor', 'TensorizerConfig' +] + +logger = init_logger(__name__) + + +class MetaTensorMode(TorchDispatchMode): + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + + if func._schema.name == "aten::empty" and "device" not in kwargs: + kwargs["device"] = "meta" + + return func(*args, **kwargs) + + +def meta_tensor_mode(loading_code=None, ): + + if loading_code is None: + return _NoInitOrTensorImpl.context_manager() + elif callable(loading_code): + with _NoInitOrTensorImpl.context_manager(): + return loading_code() + else: + raise TypeError( + "expected a callable to evaluate," + " or None if being used as a context manager;" + f' got an object of type "{type(loading_code).__name__}" instead.') + + +class _NoInitOrTensorImpl: + _MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm) + _MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES) + + is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active", + default=False) + _count_active: int = 0 + _count_active_lock = threading.Lock() + + @classmethod + @contextlib.contextmanager + def context_manager(cls): + if cls.is_active.get(): + yield + return + + with cls._count_active_lock: + cls._count_active += 1 + if cls._count_active == 1: + for mod in cls._MODULES: + mod.reset_parameters = cls._disable(mod.reset_parameters) + + reset_token = cls.is_active.set(True) + + try: + with MetaTensorMode(): + yield + finally: + cls.is_active.reset(reset_token) + with cls._count_active_lock: + cls._count_active -= 1 + if cls._count_active == 0: + for mod, original in cls._MODULE_ORIGINALS: + mod.reset_parameters = original + + @staticmethod + def _disable(func): + + def wrapper(*args, **kwargs): + if not _NoInitOrTensorImpl.is_active.get(): + return func(*args, **kwargs) + + return wrapper + + +@dataclass +class TensorizerConfig: + tensorizer_uri: Union[str, None] = None + vllm_tensorized: Optional[bool] = False + verify_hash: Optional[bool] = False + num_readers: Optional[int] = None + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + model_class: Optional[type[torch.nn.Module]] = None + hf_config: Optional[PretrainedConfig] = None + dtype: Optional[Union[str, torch.dtype]] = None + lora_dir: Optional[str] = None + _is_sharded: bool = False + + def __post_init__(self): + # check if the configuration is for a sharded vLLM model + self._is_sharded = isinstance(self.tensorizer_uri, str) \ + and re.search(r'%0\dd', self.tensorizer_uri) is not None + if not self.tensorizer_uri and not self.lora_dir: + raise ValueError("tensorizer_uri must be provided.") + if not self.tensorizer_uri and self.lora_dir: + self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors" + assert self.tensorizer_uri is not None, ("tensorizer_uri must be " + "provided.") + self.tensorizer_dir = os.path.dirname(self.tensorizer_uri) + self.lora_dir = self.tensorizer_dir + + @classmethod + def as_dict(cls, *args, **kwargs) -> dict[str, Any]: + cfg = TensorizerConfig(*args, **kwargs) + return dataclasses.asdict(cfg) + + def to_dict(self) -> dict[str, Any]: + return dataclasses.asdict(self) + + def _construct_tensorizer_args(self) -> "TensorizerArgs": + tensorizer_args = { + "tensorizer_uri": self.tensorizer_uri, + "vllm_tensorized": self.vllm_tensorized, + "verify_hash": self.verify_hash, + "num_readers": self.num_readers, + "encryption_keyfile": self.encryption_keyfile, + "s3_access_key_id": self.s3_access_key_id, + "s3_secret_access_key": self.s3_secret_access_key, + "s3_endpoint": self.s3_endpoint, + } + return TensorizerArgs(**tensorizer_args) # type: ignore + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + if parallel_config.tensor_parallel_size > 1 \ + and not self._is_sharded: + raise ValueError( + "For a sharded model, tensorizer_uri should include a" + " string format template like '%04d' to be formatted" + " with the rank of the shard") + + def verify_with_model_config(self, model_config: "ModelConfig") -> None: + if (model_config.quantization is not None + and self.tensorizer_uri is not None): + logger.warning( + "Loading a model using Tensorizer with quantization on vLLM" + " is unstable and may lead to errors.") + + def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None): + if tensorizer_args is None: + tensorizer_args = self._construct_tensorizer_args() + + return open_stream(self.tensorizer_uri, + **tensorizer_args.stream_params) + + +@dataclass +class TensorizerArgs: + tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str, + bytes, os.PathLike, int] + vllm_tensorized: Optional[bool] = False + verify_hash: Optional[bool] = False + num_readers: Optional[int] = None + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + """ + Args for the TensorizerAgent class. These are used to configure the behavior + of the TensorDeserializer when loading tensors from a serialized model. + + Args: + tensorizer_uri: Path to serialized model tensors. Can be a local file + path or a S3 URI. This is a required field unless lora_dir is + provided and the config is meant to be used for the + `tensorize_lora_adapter` function. + vllm_tensorized: If True, indicates that the serialized model is a + vLLM model. This is used to determine the behavior of the + TensorDeserializer when loading tensors from a serialized model. + It is far faster to deserialize a vLLM model as it utilizes + tensorizer's optimized GPU loading. Note that this is now + deprecated, as serialized vLLM models are now automatically + inferred as vLLM models. + verify_hash: If True, the hashes of each tensor will be verified against + the hashes stored in the metadata. A `HashMismatchError` will be + raised if any of the hashes do not match. + num_readers: Controls how many threads are allowed to read concurrently + from the source file. Default is `None`, which will dynamically set + the number of readers based on the number of available + resources and model size. This greatly increases performance. + encryption_keyfile: File path to a binary file containing a + binary key to use for decryption. `None` (the default) means + no decryption. See the example script in + examples/others/tensorize_vllm_model.py. + s3_access_key_id: The access key for the S3 bucket. Can also be set via + the S3_ACCESS_KEY_ID environment variable. + s3_secret_access_key: The secret access key for the S3 bucket. Can also + be set via the S3_SECRET_ACCESS_KEY environment variable. + s3_endpoint: The endpoint for the S3 bucket. Can also be set via the + S3_ENDPOINT_URL environment variable. + """ + + def __post_init__(self): + self.file_obj = self.tensorizer_uri + self.s3_access_key_id = self.s3_access_key_id or envs.S3_ACCESS_KEY_ID + self.s3_secret_access_key = (self.s3_secret_access_key + or envs.S3_SECRET_ACCESS_KEY) + self.s3_endpoint = self.s3_endpoint or envs.S3_ENDPOINT_URL + self.stream_params = { + "s3_access_key_id": self.s3_access_key_id, + "s3_secret_access_key": self.s3_secret_access_key, + "s3_endpoint": self.s3_endpoint, + } + + self.deserializer_params = { + "verify_hash": self.verify_hash, + "encryption": self.encryption_keyfile, + "num_readers": self.num_readers + } + + if self.encryption_keyfile: + with open_stream( + self.encryption_keyfile, + **self.stream_params, + ) as stream: + key = stream.read() + decryption_params = DecryptionParams.from_key(key) + self.deserializer_params['encryption'] = decryption_params + + @staticmethod + def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + """Tensorizer CLI arguments""" + + # Tensorizer options arg group + group = parser.add_argument_group( + 'tensorizer options', + description=('Options for configuring the behavior of the' + ' tensorizer deserializer when ' + 'load_format=tensorizer is specified when ' + 'initializing an LLMEngine, either via the CLI ' + 'when running the vLLM OpenAI inference server ' + 'with a JSON string passed to ' + '--model-loader-extra-config or as arguments given ' + 'to TensorizerConfig when passed to ' + 'model_loader_extra_config in the constructor ' + 'for LLMEngine.')) + + group.add_argument( + "--tensorizer-uri", + type=str, + help="Path to serialized model tensors. Can be a local file path," + " or an HTTP(S) or S3 URI.", + ) + group.add_argument( + "--verify-hash", + action="store_true", + help="If enabled, the hashes of each tensor will be verified" + " against the hashes stored in the file metadata. An exception" + " will be raised if any of the hashes do not match.", + ) + group.add_argument( + "--encryption-keyfile", + type=str, + default=None, + help="The file path to a binary file containing a binary key to " + "use for decryption. Can be a file path or S3 network URI.") + group.add_argument( + "--num-readers", + default=None, + type=int, + help="Controls how many threads are allowed to read concurrently " + "from the source file. Default is `None`, which will dynamically " + "set the number of readers based on the available resources " + "and model size. This greatly increases performance.") + group.add_argument( + "--s3-access-key-id", + type=str, + default=None, + help="The access key for the S3 bucket. Can also be set via the " + "S3_ACCESS_KEY_ID environment variable.", + ) + group.add_argument( + "--s3-secret-access-key", + type=str, + default=None, + help="The secret access key for the S3 bucket. Can also be set via " + "the S3_SECRET_ACCESS_KEY environment variable.", + ) + group.add_argument( + "--s3-endpoint", + type=str, + default=None, + help="The endpoint for the S3 bucket. Can also be set via the " + "S3_ENDPOINT_URL environment variable.", + ) + + return parser + + @classmethod + def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs": + attrs = [attr.name for attr in dataclasses.fields(cls)] + tensorizer_args = cls(**{ + attr: getattr(args, attr) + for attr in attrs if hasattr(args, attr) + }) + return tensorizer_args + + +def _check_tensors_on_meta_device(model: nn.Module) -> None: + for tensor in model.state_dict().values(): + if tensor.device.type == 'meta': + raise ValueError( + "The serialized model contains tensors on the meta device," + " indicating that some tensors were not loaded properly." + " Please check that the parameters of the model being" + " specified match that of the serialized model, such as" + " its quantization.") + + +def _resize_lora_embeddings(model: nn.Module): + """Modify LoRA embedding layers to use bigger tensors + to allow for adapter added tokens.""" + for child in model.modules(): + if (isinstance(child, VocabParallelEmbedding) and child.weight.shape[0] + < child.num_embeddings_per_partition): + new_weight = torch.empty(child.num_embeddings_per_partition, + child.embedding_dim, + dtype=child.weight.dtype, + device=child.weight.device) + new_weight[:child.weight.shape[0]].copy_(child.weight.data) + new_weight[child.weight.shape[0]:].fill_(0) + child.weight.data = new_weight + + +def init_tensorizer_model(tensorizer_config: TensorizerConfig, + vllm_config: VllmConfig) -> nn.Module: + assert tensorizer_config.hf_config is not None + model_args = tensorizer_config.hf_config + model_args.torch_dtype = tensorizer_config.dtype + assert tensorizer_config.model_class is not None + # TODO: Do we need to consider old-style model class? + with meta_tensor_mode(), set_current_vllm_config(vllm_config, + check_compile=True): + return tensorizer_config.model_class(vllm_config=vllm_config) + + +def deserialize_tensorizer_model(model: nn.Module, + tensorizer_config: TensorizerConfig) -> None: + tensorizer_args = tensorizer_config._construct_tensorizer_args() + before_mem = get_mem_usage() + start = time.perf_counter() + with _read_stream( + tensorizer_config.tensorizer_uri, + **tensorizer_args.stream_params) as stream, TensorDeserializer( + stream, + dtype=tensorizer_config.dtype, + device=f'cuda:{torch.cuda.current_device()}', + **tensorizer_args.deserializer_params) as deserializer: + deserializer.load_into_module(model) + end = time.perf_counter() + + total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) + duration = end - start + per_second = convert_bytes(deserializer.total_tensor_bytes / duration) + after_mem = get_mem_usage() + deserializer.close() + logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str, + end - start, per_second) + logger.info("Memory usage before: %s", before_mem) + logger.info("Memory usage after: %s", after_mem) + + _check_tensors_on_meta_device(model) + _resize_lora_embeddings(model) + del model.vllm_tensorized_marker + + +def tensorizer_weights_iterator( + tensorizer_args: "TensorizerArgs" +) -> Generator[tuple[str, torch.Tensor], None, None]: + logger.warning("Deserializing HuggingFace models is not optimized for " + "loading on vLLM, as tensorizer is forced to load to CPU. " + "Consider deserializing a vLLM model instead for faster " + "load times. See the " + "examples/others/tensorize_vllm_model.py example script " + "for serializing vLLM models.") + + deserializer_args = tensorizer_args.deserializer_params + stream_params = tensorizer_args.stream_params + stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params) + with TensorDeserializer(stream, **deserializer_args, + device="cpu") as state: + yield from state.items() + del state + + +def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: + """ + Infer if the model is a vLLM model by checking the weights for + a vLLM tensorized marker. + + Args: + tensorizer_config: The TensorizerConfig object containing the + tensorizer_uri to the serialized model. + + Returns: + bool: True if the model is a vLLM model, False otherwise. + """ + tensorizer_args = tensorizer_config._construct_tensorizer_args() + deserializer = TensorDeserializer(open_stream( + tensorizer_args.tensorizer_uri, **tensorizer_args.stream_params), + **tensorizer_args.deserializer_params, + lazy_load=True) + if tensorizer_config.vllm_tensorized: + logger.warning( + "Please note that newly serialized vLLM models are automatically " + "inferred as vLLM models, so setting vllm_tensorized=True is " + "only necessary for models serialized prior to this change.") + return True + return ".vllm_tensorized_marker" in deserializer + + +def serialize_vllm_model( + model: nn.Module, + tensorizer_config: TensorizerConfig, +) -> nn.Module: + model.register_parameter( + "vllm_tensorized_marker", + nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False)) + tensorizer_args = tensorizer_config._construct_tensorizer_args() + + encryption_params = None + if (keyfile := tensorizer_config.encryption_keyfile) is not None: + with open(keyfile, "rb") as f: + key = f.read() + encryption_params = EncryptionParams(key=key) + + output_file = tensorizer_args.tensorizer_uri + if tensorizer_config._is_sharded: + from vllm.distributed import get_tensor_model_parallel_rank + output_file = output_file % get_tensor_model_parallel_rank() + + with _write_stream(output_file, **tensorizer_args.stream_params) as stream: + serializer = TensorSerializer(stream, encryption=encryption_params) + serializer.write_module(model) + serializer.close() + logger.info("Successfully serialized model to %s", str(output_file)) + return model + + +def tensorize_vllm_model(engine_args: "EngineArgs", + tensorizer_config: TensorizerConfig, + generate_keyfile: bool = True): + """Utility to load a model and then serialize it with Tensorizer + + Intended to be used separately from running a vLLM server since it + creates its own Engine instance. + """ + engine_config = engine_args.create_engine_config() + tensorizer_config.verify_with_model_config(engine_config.model_config) + tensorizer_config.verify_with_parallel_config( + engine_config.parallel_config) + + # generate the encryption key before creating the engine to support sharding + if generate_keyfile and (keyfile := + tensorizer_config.encryption_keyfile) is not None: + encryption_params = EncryptionParams.random() + with _write_stream( + keyfile, + s3_access_key_id=tensorizer_config.s3_access_key_id, + s3_secret_access_key=tensorizer_config.s3_secret_access_key, + s3_endpoint=tensorizer_config.s3_endpoint, + ) as stream: + stream.write(encryption_params.key) + + from vllm import LLMEngine + from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine + + if not envs.VLLM_USE_V1: + engine = LLMEngine.from_engine_args(engine_args) + engine.model_executor.collective_rpc( + "save_tensorized_model", + kwargs=dict(tensorizer_config=tensorizer_config), + ) + else: + engine = V1LLMEngine.from_vllm_config(engine_config) + engine.collective_rpc( + "save_tensorized_model", + kwargs=dict(tensorizer_config=tensorizer_config), + ) + + +def tensorize_lora_adapter(lora_path: str, + tensorizer_config: TensorizerConfig): + """ + Uses tensorizer to serialize a LoRA adapter. Assumes that the files + needed to load a LoRA adapter are a safetensors-format file called + adapter_model.safetensors and a json config file called adapter_config.json. + + Serializes the files in the tensorizer_config.lora_dir + """ + import safetensors + + from vllm.lora.utils import get_adapter_absolute_path + + lora_dir = get_adapter_absolute_path(lora_path) + + tensor_path = config_path = "" + + for file in os.listdir(lora_dir): + if file.startswith("adapter_model"): + tensor_path = lora_dir + "/" + file + if file.startswith("adapter_config"): + config_path = lora_dir + "/" + file + if tensor_path and config_path: + break + + if tensor_path.endswith(".safetensors"): + tensors = safetensors.torch.load_file(tensor_path) + elif tensor_path.endswith(".bin"): + tensors = torch.load(tensor_path) + else: + raise ValueError("Unsupported file: %s", tensor_path) + + with open(config_path) as f: + config = json.load(f) + + tensorizer_args = tensorizer_config._construct_tensorizer_args() + + with open_stream(f"{tensorizer_config.lora_dir}/adapter_config.json", + mode="wb+", + **tensorizer_args.stream_params) as f: + + f.write(json.dumps(config).encode("utf-8")) + + lora_uri = (f"{tensorizer_config.lora_dir}" + f"/adapter_model.tensors") + with open_stream(lora_uri, mode="wb+", + **tensorizer_args.stream_params) as f: + serializer = TensorSerializer(f) + serializer.write_state_dict(tensors) + serializer.close() + + logger.info("Successfully serialized LoRA files to %s", + str(tensorizer_config.lora_dir)) diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py new file mode 100644 index 0000000..0b62e74 --- /dev/null +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: SIM117 +import copy +from collections.abc import Generator +from typing import Union + +import torch +from torch import nn + +from vllm.config import LoadConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.tensorizer import ( + TensorizerConfig, deserialize_tensorizer_model, init_tensorizer_model, + is_vllm_tensorized, serialize_vllm_model, tensorizer_weights_iterator) +from vllm.model_executor.model_loader.utils import (get_model_architecture, + initialize_model, + set_default_torch_dtype) + +logger = init_logger(__name__) + + +class TensorizerLoader(BaseModelLoader): + """Model loader using CoreWeave's tensorizer library.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if isinstance(load_config.model_loader_extra_config, TensorizerConfig): + self.tensorizer_config = load_config.model_loader_extra_config + else: + self.tensorizer_config = TensorizerConfig( + **load_config.model_loader_extra_config) + + def _verify_config(self, model_config: ModelConfig, + parallel_config: ParallelConfig): + self.tensorizer_config.verify_with_model_config(model_config) + self.tensorizer_config.verify_with_parallel_config(parallel_config) + + def _get_weights_iterator( + self, ) -> Generator[tuple[str, torch.Tensor], None, None]: + tensorizer_args = self.tensorizer_config._construct_tensorizer_args() + return tensorizer_weights_iterator(tensorizer_args) + + def _load_model_serialized_cpu( + self, + vllm_config: VllmConfig, + ) -> nn.Module: + """Load a serialized model with tensorizer to the CPU. + + This is only necessary when the model isn't vLLM-tensorized (see + examples/others/tensorize_vllm_model.py) This should still + be faster than default HuggingFace loading, but will be slower than + loading a vLLM-tensorized model. + """ + device_config = vllm_config.device_config + model_config = vllm_config.model_config + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = initialize_model(vllm_config=vllm_config) + + model.load_weights(self._get_weights_iterator()) + return model.eval() + + def download_model(self, model_config: ModelConfig) -> None: + self.tensorizer_config.verify_with_model_config(model_config) + + with self.tensorizer_config.open_stream(): + pass + + def _patch_tensorizer_config( + self, model_config: ModelConfig) -> TensorizerConfig: + model_class = get_model_architecture(model_config)[0] + tensorizer_config = copy.copy(self.tensorizer_config) + tensorizer_config.model_class = model_class + tensorizer_config.hf_config = model_config.hf_config + tensorizer_config.dtype = model_config.dtype + return tensorizer_config + + def load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: + """Load serialized model weights with tensorizer. + + Expects a vLLM-tensorized model. See the + examples/others/tensorize_vllm_model.py example script + for serializing vLLM models.""" + if is_vllm_tensorized(self.tensorizer_config): + tensorizer_config = self._patch_tensorizer_config(model_config) + deserialize_tensorizer_model(model, tensorizer_config) + else: + model.load_weights(self._get_weights_iterator()) + + def load_model(self, vllm_config: VllmConfig, + model_config: ModelConfig) -> nn.Module: + parallel_config = vllm_config.parallel_config + self._verify_config(model_config, parallel_config) + + if parallel_config.tensor_parallel_size > 1: + from vllm.distributed import get_tensor_model_parallel_rank + + self.tensorizer_config.tensorizer_uri = ( + self.tensorizer_config.tensorizer_uri % + get_tensor_model_parallel_rank()) + + if is_vllm_tensorized(self.tensorizer_config): + tensorizer_config = self._patch_tensorizer_config(model_config) + device_config = vllm_config.device_config + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = init_tensorizer_model( + tensorizer_config=tensorizer_config, + vllm_config=vllm_config) + self.load_weights(model, model_config) + return model + return self._load_model_serialized_cpu(vllm_config=vllm_config) + + @staticmethod + def save_model( + model: torch.nn.Module, + tensorizer_config: Union[TensorizerConfig, dict], + ) -> None: + if isinstance(tensorizer_config, dict): + tensorizer_config = TensorizerConfig(**tensorizer_config) + serialize_vllm_model( + model=model, + tensorizer_config=tensorizer_config, + ) diff --git a/vllm/model_executor/model_loader/tpu.py b/vllm/model_executor/model_loader/tpu.py new file mode 100644 index 0000000..b44c165 --- /dev/null +++ b/vllm/model_executor/model_loader/tpu.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time +from typing import Optional + +import torch +import torch.nn as nn +import torch_xla.core.xla_model as xm +import torch_xla.distributed.spmd as xs + +from vllm.config import ModelConfig, VllmConfig +from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model +from vllm.logger import init_logger +from vllm.model_executor.model_loader.default_loader import DefaultModelLoader +from vllm.model_executor.model_loader.utils import ( + initialize_model, process_weights_after_loading, set_default_torch_dtype) + +logger = init_logger(__name__) + + +class TPUModelLoader(DefaultModelLoader): + """ + A TPU model loader for model loading under SPMD mode. + """ + + def load_model( + self, + vllm_config: VllmConfig, + model_config: ModelConfig, + mesh: Optional[xs.Mesh] = None, + ) -> nn.Module: + # Initialize model and load weights on CPU. Then, during SPMD partition, + # weights are sharded and transferred to TPUs. + self.counter_before_loading_weights = time.perf_counter() + model_config = vllm_config.model_config + assert model_config.quantization is None, "Quantization not supported" + target_device = torch.device('cpu') + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = initialize_model(vllm_config=vllm_config) + + load_format = vllm_config.load_config.load_format + if load_format != "dummy": + weights_to_load = { + name + for name, _ in model.named_parameters() + } + all_weights = self.get_all_weights(model_config, model) + loaded_weights = model.load_weights(all_weights) + self.counter_after_loading_weights = time.perf_counter() + logger.info( + "Loading weights took %.2f seconds", + self.counter_after_loading_weights - + self.counter_before_loading_weights) + # We only enable strict check for non-quantized models + # that have loaded weights tracking currently. + if model_config.quantization is None and \ + loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + raise ValueError( + "Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}") + else: + logger.info("Use dummy weight during weight loading.") + + process_weights_after_loading(model, model_config, target_device) + + counter_before_partition = time.perf_counter() + model = model.eval() + model = model.to('xla') + shard_model(model, mesh) + counter_after_partition = time.perf_counter() + logger.info("Partition model took %.2f seconds", + counter_after_partition - counter_before_partition) + + # Ensure the model is properly loaded. + self._check_model_is_loaded(mesh, model) + + # Need to torch compile after model sharding are done. Because the + # compiler hints ('xs.mark_sharding') are torch ops. + if not model_config.is_multimodal_model: + model.model = torch.compile(model.model, backend="openxla") + else: + model.language_model.model = \ + torch.compile(model.language_model.model, backend="openxla") + return model + + def _check_model_is_loaded(self, mesh: Optional[xs.Mesh], + model: nn.Module) -> None: + """ + Ensure the model is properly loaded. + 1. All model parameters and buffers are on XLA device. + 2. Non-SPMD friendly layers are replaced as expected. + """ + device = xm.xla_device() + device_type = str(device.type) + + # Check parameters + for name, param in model.named_parameters(): + assert param.device.type == device_type, f"Parameter {name} is on \ + {param.device.type} instead of {device_type}" + + # Check buffers + for name, buffer in model.named_buffers(): + assert buffer.device.type == device_type, \ + f"Buffer {name} is on {buffer.device.type} instead of \ + {device_type}" + + for module in model.modules(): + if (mesh is not None) and (get_fqn(module) == 'QKVParallelLinear'): + raise AssertionError("QKVParallelLinear should be replaced by \ + XlaQKVParallelLinear under SPMD mode.") diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py new file mode 100644 index 0000000..ff6b346 --- /dev/null +++ b/vllm/model_executor/model_loader/utils.py @@ -0,0 +1,355 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Utilities for selecting and loading models.""" +import contextlib +import inspect +import warnings +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import Optional + +import os +import torch +import transformers +from torch import nn +from transformers.dynamic_module_utils import get_class_from_dynamic_module + +from vllm.attention import Attention +from vllm.config import (ModelConfig, ModelImpl, VllmConfig, + set_current_vllm_config) +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import QKVCrossParallelLinear +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.models.adapters import (as_embedding_model, + as_reward_model) + +import vllm.envs as envs +from vllm.model_executor.models.interfaces import SupportsQuant +from vllm.utils import is_pin_memory_available + +logger = init_logger(__name__) + + +@contextlib.contextmanager +def set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(old_dtype) + + +def initialize_model( + vllm_config: VllmConfig, + *, + prefix: str = "", + model_class: Optional[type[nn.Module]] = None, + model_config: Optional[ModelConfig] = None, +) -> nn.Module: + """Initialize a model with the given configurations.""" + if model_config is None: + model_config = vllm_config.model_config + if model_class is None: + model_class, _ = get_model_architecture(model_config) + + if vllm_config.quant_config is not None: + configure_quant_config(vllm_config.quant_config, model_class) + + signatures = inspect.signature(model_class.__init__) + all_params = [param.name for param in signatures.parameters.values()] + if "vllm_config" in all_params and "prefix" in all_params: + # new-style model class + with set_current_vllm_config(vllm_config, + check_compile=True, + prefix=prefix): + return model_class(vllm_config=vllm_config, prefix=prefix) + + msg = ("vLLM model class should accept `vllm_config` and `prefix` as " + "input arguments. Possibly you have an old-style model class" + " registered from out of tree and it is used for new vLLM version. " + "Check https://docs.vllm.ai/en/latest/design/arch_overview.html " + "for the design and update the model class accordingly.") + warnings.warn(msg, DeprecationWarning, stacklevel=2) + + logger.warning( + "Trying to guess the arguments for old-style model class %s", + model_class, + ) + # try to be compatible with old-style model class + kwargs = {} + if "prefix" in all_params: + kwargs["prefix"] = prefix + if "config" in all_params: + kwargs["config"] = model_config.hf_config + if "cache_config" in all_params: + kwargs["cache_config"] = vllm_config.cache_config + if "quant_config" in all_params: + kwargs["quant_config"] = vllm_config.quant_config + if "lora_config" in all_params: + kwargs["lora_config"] = vllm_config.lora_config + if "scheduler_config" in all_params: + kwargs["scheduler_config"] = vllm_config.scheduler_config + with set_current_vllm_config(vllm_config, + check_compile=True, + prefix=prefix): + return model_class(**kwargs) + + +def process_weights_after_loading(model: nn.Module, model_config: ModelConfig, + target_device: torch.device) -> None: + for _, module in model.named_modules(): + if isinstance(module, QKVCrossParallelLinear): + # NOTE(Isotr0py): special case for cross QKV layer because + # q and kv proj aren't registered as submodules intentionally + module.process_weights_after_loading() + continue + quant_method = getattr(module, "quant_method", None) + if isinstance(quant_method, QuantizeMethodBase): + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + + # Currently only used by MLA. + # NOTE: This intentionally happens after other modules so we can easily + # decompress the weights for MLA. + for _, module in model.named_modules(): + if isinstance(module, Attention) and \ + hasattr(module, "process_weights_after_loading"): + # TODO(lucas): see if there is a way to unify the signatures + # of process_weights_after_loading + module.process_weights_after_loading(model_config.dtype) + + +@contextmanager +def device_loading_context(module: torch.nn.Module, + target_device: torch.device): + if target_device.type == "cpu": + # If target is CPU, no need to move anything + yield module + return + + original_device_states: dict[str, torch.device] = {} + + # Store original device states and move parameters to GPU if they're on CPU + for name, p in module.named_parameters(): + if p.device.type == "cpu": + original_device_states[name] = p.device + p.data = p.data.to(target_device) + # Parameters already on target device are not touched + + try: + yield module + + finally: + # Restore parameters to their original devices, ignoring new parameters + pin_memory = is_pin_memory_available() + for name, p in module.named_parameters(): + if name in original_device_states: + original_device: torch.device = original_device_states[name] + if original_device.type == "cpu": + # `torch.empty_like` does not support `pin_memory` argument + cpu_data = torch.empty_strided( + size=p.data.size(), + stride=p.data.stride(), + dtype=p.data.dtype, + layout=p.data.layout, + device="cpu", + pin_memory=pin_memory, + ) + cpu_data.copy_(p.data) + p.data = cpu_data + else: + p.data = p.data.to(original_device) + # New parameters or parameters already on target device are untouched + + +def resolve_transformers_arch(model_config: ModelConfig, + architectures: list[str]): + for i, arch in enumerate(architectures): + if arch == "TransformersForCausalLM": + continue + auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map", + None) or dict() + # Make sure that config class is always initialized before model class, + # otherwise the model class won't be able to access the config class, + # the expected auto_map should have correct order like: + # "auto_map": { + # "AutoConfig": "--", + # "AutoModel": "--", + # "AutoModelFor": "--", + # }, + auto_modules = { + name: + get_class_from_dynamic_module(module, + model_config.model, + revision=model_config.revision) + for name, module in sorted(auto_map.items(), key=lambda x: x[0]) + } + model_module = getattr(transformers, arch, None) + if model_module is None: + if "AutoModel" not in auto_map: + raise ValueError( + f"Cannot find model module. '{arch}' is not a registered " + "model in the Transformers library (only relevant if the " + "model is meant to be in Transformers) and 'AutoModel' is " + "not present in the model config's 'auto_map' (relevant " + "if the model is custom).") + model_module = auto_modules["AutoModel"] + # TODO(Isotr0py): Further clean up these raises. + # perhaps handled them in _ModelRegistry._raise_for_unsupported? + if model_config.model_impl == ModelImpl.TRANSFORMERS: + if not model_module.is_backend_compatible(): + raise ValueError( + f"The Transformers implementation of {arch} is not " + "compatible with vLLM.") + architectures[i] = "TransformersForCausalLM" + if model_config.model_impl == ModelImpl.AUTO: + if not model_module.is_backend_compatible(): + raise ValueError( + f"{arch} has no vLLM implementation and the Transformers " + "implementation is not compatible with vLLM. Try setting " + "VLLM_USE_V1=0.") + logger.warning( + "%s has no vLLM implementation, falling back to Transformers " + "implementation. Some features may not be supported and " + "performance may not be optimal.", arch) + architectures[i] = "TransformersForCausalLM" + return architectures + + +def get_model_architecture( + model_config: ModelConfig) -> tuple[type[nn.Module], str]: + architectures = getattr(model_config.hf_config, "architectures", []) + visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", []) + # TODO: 'Qwen2_5_VLForConditionalGeneration', + support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2VLForConditionalGeneration', 'Qwen2MoeForCausalLM', 'Qwen3ForCausalLM', 'Qwen3MoeForCausalLM', + 'ChatGLMModel', 'Glm4ForCausalLM', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'TeleChat2ForCausalLM', 'MixtralForCausalLM', 'FalconForCausalLM', + 'MedusaModel', 'MLPSpeculatorPreTrainedModel', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM', 'DeepSeekMTPModel'] + if any(arch in architectures for arch in support_nn_architectures): + if not envs.VLLM_USE_NN: + if os.getenv('LLAMA_NN') != '0': + if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []: + os.environ['LLAMA_NN'] = '0' + else: + os.environ['LLAMA_NN'] = '1' + if (architectures == ['BloomForCausalLM'] or architectures == ['FalconForCausalLM']) or os.getenv('LM_NN') == '0': + os.environ['LM_NN'] = '0' + else: + os.environ['LM_NN'] = '1' + if os.getenv('GEMM_PAD') != '1': + os.environ['GEMM_PAD'] = '0' + if os.getenv('FA_PAD') != '1': + os.environ['FA_PAD'] = '0' + # awq相关配置 + try: + if os.getenv('AWQ_MOE_SZ') == None: + os.environ['AWQ_MOE_SZ'] = '1' + if os.getenv('AWQ_PAD') == None and (torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120): + os.environ['AWQ_PAD'] = '1' + except Exception as e: + if os.getenv('AWQ_PAD') != '0': + os.environ['AWQ_PAD'] = '1' + else: + os.environ['AWQ_PAD'] = '0' + else: + os.environ['LLAMA_NN'] = '0' + os.environ['LM_NN'] = '0' + os.environ['GEMM_PAD'] = '0' + os.environ['FA_PAD'] = '0' + os.environ['AWQ_PAD'] = '0' + + # Special handling for quantized Mixtral. + # FIXME(woosuk): This is a temporary hack. + mixtral_supported = [ + "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark" + ] + + vllm_supported_archs = ModelRegistry.get_supported_archs() + vllm_not_supported = not any(arch in vllm_supported_archs + for arch in architectures) + if (model_config.model_impl == ModelImpl.TRANSFORMERS or + model_config.model_impl != ModelImpl.VLLM and vllm_not_supported): + architectures = resolve_transformers_arch(model_config, architectures) + elif (model_config.quantization is not None + and model_config.quantization not in mixtral_supported + and "MixtralForCausalLM" in architectures): + architectures = ["QuantMixtralForCausalLM"] + + model_cls, arch = ModelRegistry.resolve_model_cls(architectures) + if model_config.task == "embed": + model_cls = as_embedding_model(model_cls) + elif model_config.task == "classify": + # Cannot automatically run as_seq_cls_model, + # otherwise it will cause a circular reference on is_cross_encoder_model + pass + elif model_config.task == "reward": + model_cls = as_reward_model(model_cls) + + return model_cls, arch + + +def get_model_cls(model_config: ModelConfig) -> type[nn.Module]: + return get_model_architecture(model_config)[0] + + +def get_architecture_class_name(model_config: ModelConfig) -> str: + return get_model_architecture(model_config)[1] + + +@dataclass +class ParamMapping: + """ + A class to handle parameter mapping for model weight loading. + It creates a bidirectional mapping between packed parameters and their + constituent parts. + """ + packed_mapping: dict[str, list[str]] + inverse_packed_mapping: dict[str, tuple[str, + int]] = field(default_factory=dict) + + def __post_init__(self): + for packed_name, sub_params in self.packed_mapping.items(): + # Skip self-contained cases (e.g., {"W_pack": ["W_pack"]}) + if len(sub_params) == 1 and sub_params[0] == packed_name: + continue + for index, param_name in enumerate(sub_params): + self.inverse_packed_mapping[param_name] = ( + packed_name, + index, + ) + + def get_sub_modules(self, + module_name: str) -> Optional[tuple[str, list[str]]]: + for key, value in self.packed_mapping.items(): + if module_name.endswith(key): + return key, value + return None + + +def configure_quant_config(quant_config: QuantizationConfig, + model_class: type[nn.Module]): + """ + Pass packed_modules_mapping by reference to quant_config so that + quant_config can properly match fused modules + + Note that model attributes are passed by reference to quant_config, + enabling them to be updated by model_class.__new__ (ex. chatglm, qwen) + + Once the `SupportsQuant` mixin has been added to all models, this + function can be removed + """ + if not issubclass(model_class, SupportsQuant): + hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None) + packed_mapping = getattr(model_class, "packed_modules_mapping", None) + + # pass mappings by reference to quant_config + if hf_to_vllm_mapper is not None: + quant_config.apply_vllm_mapper(hf_to_vllm_mapper) + if packed_mapping is not None: + quant_config.packed_modules_mapping = packed_mapping diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py new file mode 100644 index 0000000..fd62853 --- /dev/null +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -0,0 +1,799 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Utilities for downloading and initializing model weights.""" +import fnmatch +import glob +import hashlib +import json +import os +import tempfile +import time +from collections import defaultdict +from collections.abc import Generator +from pathlib import Path +from typing import Any, Callable, Optional, Union + +import filelock +import gguf +import huggingface_hub.constants +import numpy as np +import torch +from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download +from safetensors.torch import load_file, safe_open, save_file +from tqdm.auto import tqdm + +from vllm.config import LoadConfig, ModelConfig +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import (QuantizationConfig, + get_quantization_config) +from vllm.platforms import current_platform +from vllm.utils import PlaceholderModule + +try: + from runai_model_streamer import SafetensorsStreamer +except (ImportError, OSError): + # see https://github.com/run-ai/runai-model-streamer/issues/26 + # OSError will be raised on arm64 platform + runai_model_streamer = PlaceholderModule( + "runai_model_streamer") # type: ignore[assignment] + SafetensorsStreamer = runai_model_streamer.placeholder_attr( + "SafetensorsStreamer") + +try: + from fastsafetensors import SafeTensorsFileLoader, SingleGroup +except ImportError: + fastsafetensors = PlaceholderModule("fastsafetensors") + SafeTensorsFileLoader = fastsafetensors.placeholder_attr( + "SafeTensorsFileLoader") + SingleGroup = fastsafetensors.placeholder_attr("SingleGroup") + +logger = init_logger(__name__) + +# use system-level temp directory for file locks, so that multiple users +# can share the same lock without error. +# lock files in the temp directory will be automatically deleted when the +# system reboots, so users will not complain about annoying lock files +temp_dir = tempfile.gettempdir() + + +def enable_hf_transfer(): + """automatically activates hf_transfer + """ + if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: + try: + # enable hf hub transfer if available + import hf_transfer # type: ignore # noqa + huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True + except ImportError: + pass + + +enable_hf_transfer() + + +class DisabledTqdm(tqdm): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, disable=True) + + +def get_lock(model_name_or_path: Union[str, Path], + cache_dir: Optional[str] = None): + lock_dir = cache_dir or temp_dir + model_name_or_path = str(model_name_or_path) + os.makedirs(os.path.dirname(lock_dir), exist_ok=True) + model_name = model_name_or_path.replace("/", "-") + hash_name = hashlib.sha256(model_name.encode()).hexdigest() + # add hash to avoid conflict with old users' lock files + lock_file_name = hash_name + model_name + ".lock" + # mode 0o666 is required for the filelock to be shared across users + lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), + mode=0o666) + return lock + + +def _shared_pointers(tensors): + ptrs = defaultdict(list) + for k, v in tensors.items(): + ptrs[v.data_ptr()].append(k) + failing = [] + for _, names in ptrs.items(): + if len(names) > 1: + failing.append(names) + return failing + + +def convert_bin_to_safetensor_file( + pt_filename: str, + sf_filename: str, +) -> None: + loaded = torch.load(pt_filename, map_location="cpu", weights_only=True) + if "state_dict" in loaded: + loaded = loaded["state_dict"] + shared = _shared_pointers(loaded) + for shared_weights in shared: + for name in shared_weights[1:]: + loaded.pop(name) + + # For tensors to be contiguous + loaded = {k: v.contiguous() for k, v in loaded.items()} + + dirname = os.path.dirname(sf_filename) + os.makedirs(dirname, exist_ok=True) + save_file(loaded, sf_filename, metadata={"format": "pt"}) + + # check file size + sf_size = os.stat(sf_filename).st_size + pt_size = os.stat(pt_filename).st_size + if (sf_size - pt_size) / pt_size > 0.01: + raise RuntimeError(f"""The file size different is more than 1%: + - {sf_filename}: {sf_size} + - {pt_filename}: {pt_size} + """) + + # check if the tensors are the same + reloaded = load_file(sf_filename) + for k in loaded: + pt_tensor = loaded[k] + sf_tensor = reloaded[k] + if not torch.equal(pt_tensor, sf_tensor): + raise RuntimeError(f"The output tensors do not match for key {k}") + + +# TODO(woosuk): Move this to other place. +def get_quant_config(model_config: ModelConfig, + load_config: LoadConfig) -> QuantizationConfig: + + quant_cls = get_quantization_config(model_config.quantization) + + # GGUF doesn't have config file + if model_config.quantization == "gguf": + return quant_cls.from_config({}) + + # Read the quantization config from the HF model config, if available. + hf_quant_config = getattr(model_config.hf_config, "quantization_config", + None) + # some vision model may keep quantization_config in their text_config + hf_text_config = getattr(model_config.hf_config, "text_config", None) + if hf_quant_config is None and hf_text_config is not None: + hf_quant_config = getattr(hf_text_config, "quantization_config", None) + if hf_quant_config is None: + # compressed-tensors uses a compressions_config + hf_quant_config = getattr(model_config.hf_config, "compression_config", + None) + if hf_quant_config is not None: + return quant_cls.from_config(hf_quant_config) + # Inflight BNB quantization + if model_config.quantization == "bitsandbytes": + return quant_cls.from_config({}) + is_local = os.path.isdir(model_config.model) + if not is_local: + # Download the config files. + with get_lock(model_config.model, load_config.download_dir): + hf_folder = snapshot_download( + model_config.model, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + tqdm_class=DisabledTqdm, + ) + else: + hf_folder = model_config.model + + possible_config_filenames = quant_cls.get_config_filenames() + + # If the quantization config is not found, use the default config. + if not possible_config_filenames: + return quant_cls() + + config_files = glob.glob(os.path.join(hf_folder, "*.json")) + + quant_config_files = [ + f for f in config_files if any( + f.endswith(x) for x in possible_config_filenames) + ] + if len(quant_config_files) == 0: + raise ValueError( + f"Cannot find the config file for {model_config.quantization}") + if len(quant_config_files) > 1: + raise ValueError( + f"Found multiple config files for {model_config.quantization}: " + f"{quant_config_files}") + + quant_config_file = quant_config_files[0] + with open(quant_config_file) as f: + config = json.load(f) + + if model_config.quantization == "bitsandbytes": + config["adapter_name_or_path"] = model_config.model + elif model_config.quantization == "modelopt": + if config["producer"]["name"] == "modelopt": + return quant_cls.from_config(config) + else: + raise ValueError( + f"Unsupported quantization config" + f" found for {model_config.quantization} in {f}.") + + return quant_cls.from_config(config) + + +def get_sparse_attention_config( + model_config: ModelConfig, + load_config: LoadConfig, + sparse_attention_config_filename: str = "sparse_attention_config.json", +) -> dict[str, Any]: + model_name_or_path = model_config.model + is_local = os.path.isdir(model_name_or_path) + if not is_local: + # Download the config files. + with get_lock(model_name_or_path, load_config.download_dir): + hf_folder = snapshot_download( + model_name_or_path, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + tqdm_class=DisabledTqdm, + ) + else: + hf_folder = model_name_or_path + + config_file = os.path.join(hf_folder, sparse_attention_config_filename) + if not os.path.exists(config_file): + return {} + + # Load the sparse attention config. + with open(config_file) as f: + config = json.load(f) + logger.info("Loaded sparse attention config from %s", config_file) + + return config + + +def download_weights_from_hf( + model_name_or_path: str, + cache_dir: Optional[str], + allow_patterns: list[str], + revision: Optional[str] = None, + ignore_patterns: Optional[Union[str, list[str]]] = None, +) -> str: + """Download model weights from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + allow_patterns (list[str]): The allowed patterns for the + weight files. Files matched by any of the patterns will be + downloaded. + revision (Optional[str]): The revision of the model. + ignore_patterns (Optional[Union[str, list[str]]]): The patterns to + filter out the weight files. Files matched by any of the patterns + will be ignored. + + Returns: + str: The path to the downloaded model weights. + """ + local_only = huggingface_hub.constants.HF_HUB_OFFLINE + if not local_only: + # Before we download we look at that is available: + fs = HfFileSystem() + file_list = fs.ls(model_name_or_path, detail=False, revision=revision) + + # depending on what is available we download different things + for pattern in allow_patterns: + matching = fnmatch.filter(file_list, pattern) + if len(matching) > 0: + allow_patterns = [pattern] + break + + logger.info("Using model weights format %s", allow_patterns) + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + start_time = time.perf_counter() + hf_folder = snapshot_download( + model_name_or_path, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + cache_dir=cache_dir, + tqdm_class=DisabledTqdm, + revision=revision, + local_files_only=local_only, + ) + time_taken = time.perf_counter() - start_time + if time_taken > 0.5: + logger.info("Time spent downloading weights for %s: %.6f seconds", + model_name_or_path, time_taken) + return hf_folder + + +def download_safetensors_index_file_from_hf( + model_name_or_path: str, + index_file: str, + cache_dir: Optional[str], + revision: Optional[str] = None, +) -> None: + """Download hf safetensors index file from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + index_file (str): The safetensors index file name + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + revision (Optional[str]): The revision of the model. + """ + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + try: + # Download the safetensors index file. + hf_hub_download( + repo_id=model_name_or_path, + filename=index_file, + cache_dir=cache_dir, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) + # If file not found on remote or locally, we should not fail since + # only some models will have index_file. + except huggingface_hub.utils.LocalEntryNotFoundError: + logger.info("No %s found in local cache.", index_file) + except huggingface_hub.utils.EntryNotFoundError: + logger.info("No %s found in remote.", index_file) + + +# For models like Mistral-7B-v0.3, there are both sharded +# safetensors files and a consolidated safetensors file. +# Passing both of these to the weight loader functionality breaks. +# So, we use the index_file to +# look up which safetensors files should be used. +def filter_duplicate_safetensors_files(hf_weights_files: list[str], + hf_folder: str, + index_file: str) -> list[str]: + # model.safetensors.index.json is a mapping from keys in the + # torch state_dict to safetensors file holding that weight. + index_file_name = os.path.join(hf_folder, index_file) + if not os.path.isfile(index_file_name): + return hf_weights_files + + # Iterate through the weight_map (weight_name: safetensors files) + # to identify weights that we should use. + with open(index_file_name) as f: + weight_map = json.load(f)["weight_map"] + weight_files_in_index = set() + for weight_name in weight_map: + weight_files_in_index.add( + os.path.join(hf_folder, weight_map[weight_name])) + # Filter out any fields that are not found in the index file. + hf_weights_files = [ + f for f in hf_weights_files if f in weight_files_in_index + ] + return hf_weights_files + + +def filter_files_not_needed_for_inference( + hf_weights_files: list[str]) -> list[str]: + """ + Exclude files that are not needed for inference. + + See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 + """ + blacklist = [ + "training_args.bin", + "optimizer.bin", + "optimizer.pt", + "scheduler.pt", + "scaler.pt", + ] + hf_weights_files = [ + f for f in hf_weights_files + if not any(f.endswith(x) for x in blacklist) + ] + return hf_weights_files + + +# explicitly use pure text format, with a newline at the end +# this makes it impossible to see the animation in the progress bar +# but will avoid messing up with ray or multiprocessing, which wraps +# each line of output with some prefix. +_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501 + + +def enable_tqdm(use_tqdm_on_load: bool): + return use_tqdm_on_load and (not torch.distributed.is_initialized() + or torch.distributed.get_rank() == 0) + + +def np_cache_weights_iterator( + model_name_or_path: str, + cache_dir: Optional[str], + hf_folder: str, + hf_weights_files: list[str], + use_tqdm_on_load: bool, +) -> Generator[tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model np files. + + Will dump the model weights to numpy files if they are not already dumped. + """ + # Convert the model weights from torch tensors to numpy arrays for + # faster loading. + np_folder = os.path.join(hf_folder, "np") + os.makedirs(np_folder, exist_ok=True) + weight_names_file = os.path.join(np_folder, "weight_names.json") + # Use file lock to prevent multiple processes from + # dumping the same model weights to numpy at the same time. + with get_lock(model_name_or_path, cache_dir): + if not os.path.exists(weight_names_file): + weight_names: list[str] = [] + for bin_file in tqdm( + hf_weights_files, + desc="Loading np_cache checkpoint shards", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, + ): + state = torch.load(bin_file, + map_location="cpu", + weights_only=True) + for name, param in state.items(): + param_path = os.path.join(np_folder, name) + with open(param_path, "wb") as f: + np.save(f, param.cpu().detach().numpy()) + weight_names.append(name) + with open(weight_names_file, "w") as f: + json.dump(weight_names, f) + + with open(weight_names_file) as f: + weight_names = json.load(f) + + for name in weight_names: + param_path = os.path.join(np_folder, name) + with open(param_path, "rb") as f: + param = np.load(f) + yield name, torch.from_numpy(param) + + +def safetensors_weights_iterator( + hf_weights_files: list[str], + use_tqdm_on_load: bool, +) -> Generator[tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model safetensor files.""" + total_count = 0 + for st_file in hf_weights_files: + with safe_open(st_file, framework="pt") as f: + total_count += len(f.keys()) + current_count = 0 + for st_file in tqdm( + hf_weights_files, + desc="Loading safetensors checkpoint shards", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, + ): + with safe_open(st_file, framework="pt") as f: + for name in f.keys(): # noqa: SIM118 + current_count += 1 + param = f.get_tensor(name) + param.current_count = current_count + param.total_count = total_count + yield name, param + + +def runai_safetensors_weights_iterator( + hf_weights_files: list[str], + use_tqdm_on_load: bool, +) -> Generator[tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model safetensor files.""" + with SafetensorsStreamer() as streamer: + for st_file in tqdm( + hf_weights_files, + desc="Loading safetensors using Runai Model Streamer", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, + ): + streamer.stream_file(st_file) + yield from streamer.get_tensors() + + +def fastsafetensors_weights_iterator( + hf_weights_files: list[str], + use_tqdm_on_load: bool, +) -> Generator[tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model safetensor files + using fastsafetensor library.""" + if torch.distributed.is_initialized(): + pg = torch.distributed.group.WORLD + else: + pg = SingleGroup() + + device = torch.device(f'cuda:{pg.rank()}') + weight_files_sub_lists = [ + hf_weights_files[i:i + pg.size()] + for i in range(0, len(hf_weights_files), pg.size()) + ] + + for f_list in tqdm( + weight_files_sub_lists, + desc="Loading safetensors using Fastsafetensor loader", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, + ): + loader = SafeTensorsFileLoader(pg, device) + rank_file_map = {i: [f] for i, f in enumerate(f_list)} + loader.add_filenames(rank_file_map) + try: + fb = loader.copy_files_to_device() + try: + keys = list(fb.key_to_rank_lidx.keys()) + for k in keys: + t = fb.get_tensor(k) + yield k, t + finally: + fb.close() + finally: + loader.close() + + +def pt_weights_iterator( + hf_weights_files: list[str], + use_tqdm_on_load: bool, + pt_load_map_location: Union[str, dict[str, str]] = "cpu", +) -> Generator[tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model bin/pt files.""" + total_count = 0 + for bin_file in hf_weights_files: + state = torch.load(bin_file, map_location=pt_load_map_location, weights_only=True) + total_count += len(state) + del state + + current_count = 0 + for bin_file in tqdm( + hf_weights_files, + desc="Loading pt checkpoint shards", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, + ): + state = torch.load(bin_file, map_location=pt_load_map_location, weights_only=True) + for name, param in state.items(): + current_count += 1 + param.current_count = current_count + param.total_count = total_count + yield name, param + del state + + +def get_gguf_extra_tensor_names( + gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> list[str]: + reader = gguf.GGUFReader(gguf_file) + expected_gguf_keys = set(gguf_to_hf_name_map.keys()) + exact_gguf_keys = set([tensor.name for tensor in reader.tensors]) + extra_keys = expected_gguf_keys - exact_gguf_keys + return [gguf_to_hf_name_map[key] for key in extra_keys] + + +def gguf_quant_weights_iterator( + gguf_file: str, gguf_to_hf_name_map: dict[str, str] +) -> Generator[tuple[str, torch.Tensor], None, None]: + """ + Iterate over the quant weights in the model gguf files and convert + them to torch tensors + """ + + reader = gguf.GGUFReader(gguf_file) + + for tensor in reader.tensors: + if tensor.name in gguf_to_hf_name_map: + weight_type = tensor.tensor_type + name = gguf_to_hf_name_map[tensor.name] + + if weight_type.name != "F32": + weight_type_name = name.replace("weight", "qweight_type") + weight_type = torch.tensor(weight_type) + yield weight_type_name, weight_type + + for tensor in reader.tensors: + if tensor.name in gguf_to_hf_name_map: + weight = tensor.data + weight_type = tensor.tensor_type + name = gguf_to_hf_name_map[tensor.name] + if weight_type.name != "F32": + name = name.replace("weight", "qweight") + param = torch.tensor(weight) + yield name, param + + +def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: + """convert PySafeSlice object from safetensors to torch.Tensor + + PySafeSlice object supports indexing, which is done before loading the + actual tensor and can reduce the amount of memory being read into the + memory. However, it does not support more advanced functionalities + like `.view()` or `.t()`. Therefore, if we need to modify the loaded + tensor with these more complicated operators, we need to convert to + tensor first. + """ + if not isinstance(x, torch.Tensor): + x = x[:] + return x + + +def default_weight_loader(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + """Default weight loader.""" + try: + if param.numel() == 1 and loaded_weight.numel() == 1: + # Sometimes scalar values aren't considered tensors with shapes + # so if both param and loaded_weight are a scalar, + # "broadcast" instead of copy + param.data.fill_(loaded_weight.item()) + else: + assert param.size() == loaded_weight.size(), ( + f"Attempted to load weight ({loaded_weight.size()}) " + f"into parameter ({param.size()})") + + param.data.copy_(loaded_weight) + except Exception: + # NOTE: This exception is added for the purpose of setting breakpoint to + # debug weight loading issues. + raise + + +def row_parallel_weight_loader(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + """Load weights that are row-parallelized.""" + tp_rank = get_tensor_model_parallel_rank() + shard_dim = 0 if param.dim() != 1 else None + + if shard_dim is not None: + shard_size = param.data.shape[shard_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size) + + return default_weight_loader(param, loaded_weight) + + +LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None] + + +def sharded_weight_loader(shard_axis: int) -> LoaderFunction: + """Create a weight loader that shards the weights along the given axis""" + + def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + tp_rank = get_tensor_model_parallel_rank() + + shard_size = param.data.shape[shard_axis] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size) + + return default_weight_loader(param, loaded_weight) + + return loader + + +def composed_weight_loader( + loader: LoaderFunction, fn: Callable[[torch.Tensor], + torch.Tensor]) -> LoaderFunction: + """Create a weight loader that post-processes the weights after loading""" + + def composed_loader(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + loader(param, loaded_weight) + param.data.copy_(fn(param)) + return + + return composed_loader + + +def initialize_dummy_weights( + model: torch.nn.Module, + low: float = -1e-3, + high: float = 1e-3, + seed: int = 1234, +) -> None: + """Initialize model weights with random values. + + The model weights must be randomly initialized for accurate performance + measurements. Additionally, the model weights should not cause NaNs in the + forward pass. We empirically found that initializing the weights with + values between -1e-3 and 1e-3 works well for most models. + + We use per-parameter random seed, so that dummy weights are consistent, + even if the model is partitioned across multiple devices. When the seed + is fixed, the random values generated by this function only depends on + the parameter's number of elements and its data type. + """ + for param in model.state_dict().values(): + if torch.is_floating_point(param): + if current_platform.is_tpu(): + generator = torch.Generator(device="cpu") + generator.manual_seed(seed) + # Note: The param.uniform_ function cannot be used in this + # context because it demands more TPU HBM than directly copying + # from a CPU tensor. + # Note: We avoid using torch.rank_like as it doesn't currently + # support the generator argument. + param.copy_((high - low) * + torch.rand(param.shape, + generator=generator, + dtype=param.dtype, + layout=param.layout, + requires_grad=param.requires_grad, + device="cpu") + low) + torch._sync(param) + continue + + generator = torch.Generator(device=param.data.device) + generator.manual_seed(seed) + if torch.finfo(param.data.dtype).bits < 16: + # uniform_ doesn't support < 16-bit datatypes (FP8) + dtype = param.data.dtype + tmp_param = param.data.to(torch.float16) + tmp_param = tmp_param.uniform_(low, high, + generator=generator).to(dtype) + param.data.copy_(tmp_param) + else: + param.uniform_(low, high, generator=generator) + + +def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: + """Remap the name of FP8 k/v_scale parameters. + + This function handles the remapping of FP8 k/v_scale parameter names. + It detects if the given name ends with a suffix and attempts to remap + it to the expected name format in the model. If the remapped name is not + found in the params_dict, a warning is printed and None is returned. + + Args: + name (str): The original loaded checkpoint parameter name. + params_dict (dict): Dictionary containing the model's named parameters. + + Returns: + str: The remapped parameter name if successful, or the original name + if no remapping is needed. + None: If the remapped name is not found in params_dict. + """ + if name.endswith(".kv_scale"): + logger.warning_once( + "DEPRECATED. Found kv_scale in the checkpoint. " + "This format is deprecated in favor of separate k_scale and " + "v_scale tensors and will be removed in a future release. " + "Functionally, we will remap kv_scale to k_scale and duplicate " + "k_scale to v_scale") + # NOTE: we remap the deprecated kv_scale to k_scale + remapped_name = name.replace(".kv_scale", ".attn.k_scale") + if remapped_name not in params_dict: + logger.warning_once( + "Found kv_scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv_scale is not loaded.", # noqa: E501 + name, + remapped_name, + ) + return None + return remapped_name + + possible_scale_names = [".k_scale", ".v_scale"] + modelopt_scale_names = [ + ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale" + ] + for scale_name in possible_scale_names: + if name.endswith(scale_name): + if any(mo_scale_name in name + for mo_scale_name in modelopt_scale_names): + remapped_name = name.replace( + f".self_attn.{scale_name[1]}_proj{scale_name}", + f".self_attn.attn{scale_name}") + else: + remapped_name = name.replace(scale_name, f".attn{scale_name}") + if remapped_name not in params_dict: + logger.warning_once( + "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.", # noqa: E501 + scale_name, + name, + remapped_name, + scale_name, + ) + return None + return remapped_name + + # If there were no matches, return the untouched param name + return name diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py new file mode 100644 index 0000000..d3ee687 --- /dev/null +++ b/vllm/model_executor/models/__init__.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal, + SupportsPP, SupportsTranscription, SupportsV0Only, + has_inner_state, supports_lora, supports_multimodal, + supports_pp, supports_transcription, supports_v0_only) +from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration, + is_pooling_model, is_text_generation_model) +from .registry import ModelRegistry + +__all__ = [ + "ModelRegistry", + "VllmModelForPooling", + "is_pooling_model", + "VllmModelForTextGeneration", + "is_text_generation_model", + "HasInnerState", + "has_inner_state", + "SupportsLoRA", + "supports_lora", + "SupportsMultiModal", + "supports_multimodal", + "SupportsPP", + "supports_pp", + "SupportsTranscription", + "supports_transcription", + "SupportsV0Only", + "supports_v0_only", +] diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py new file mode 100644 index 0000000..78d86f6 --- /dev/null +++ b/vllm/model_executor/models/adapters.py @@ -0,0 +1,375 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast + +import torch +import torch.nn as nn + +from vllm.model_executor.models.config import VerifyAndUpdateConfig + +from .interfaces_base import VllmModelForPooling, is_pooling_model + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.model_executor.layers.pooler import PoolingType + +_T = TypeVar("_T", bound=type[nn.Module]) + +_GENERATE_SUFFIXES = [ + "ForCausalLM", + "ForConditionalGeneration", + "ChatModel", + "LMHeadModel", +] + + +def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: + model_name = orig_model_name + + for generate_suffix in _GENERATE_SUFFIXES: + model_name = model_name.removesuffix(generate_suffix) + + return model_name + pooling_suffix + + +def _create_pooling_model_cls( + orig_cls: _T, + *, + default_pooling_type: "PoolingType", + default_normalize: bool, + default_softmax: bool, +) -> _T: + # Lazy import + from vllm.model_executor.layers.pooler import Pooler, PoolerOutput + from vllm.model_executor.pooling_metadata import PoolingMetadata + + from .utils import AutoWeightsLoader, WeightsMapper + + class ModelForPooling(orig_cls, VllmModelForPooling): + + def __init__( + self, + *, + vllm_config: "VllmConfig", + prefix: str = "", + **kwargs: Any, + ) -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + + # These are not used in pooling models + for attr in ("lm_head", "logits_processor"): + if hasattr(self, attr): + delattr(self, attr) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + # If the model already defines a pooler instance, don't overwrite it + if not getattr(self, "_pooler", None): + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=default_pooling_type, + normalize=default_normalize, + softmax=default_softmax, + ) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + # TODO: Support uninitialized params tracking + + # We have deleted this attribute, so don't load it + weights = ((name, data) for name, data in weights + if not name.startswith("lm_head.")) + + # If `*ForCausalLM` defines `load_weights` on the inner model + # and there are no other inner modules with parameters, + # we support loading from both `*Model` and `*ForCausalLM` + if hasattr(self, "model") and hasattr(self.model, "load_weights"): + # Whether only `self.model` contains parameters + model_is_only_param = all( + name == "model" or next(child.parameters(), None) is None + for name, child in self.named_children()) + + if model_is_only_param: + mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + weights = mapper.apply(weights) + + loaded_params = self.model.load_weights(weights) + loaded_params = {f"model.{name}" for name in loaded_params} + return loaded_params + + # For most other models + if hasattr(orig_cls, "load_weights"): + return orig_cls.load_weights(self, weights) # type: ignore + # Fallback + else: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + return ModelForPooling # type: ignore + + +def as_embedding_model(cls: _T) -> _T: + """ + Subclass an existing vLLM model to support embeddings. + + By default, the embeddings of the whole prompt are extracted from the + normalized hidden state corresponding to the last token. + + Note: + We assume that no extra layers are added to the original model; + please implement your own model if this is not the case. + """ + # Avoid modifying existing embedding models + if is_pooling_model(cls): + return cls + + # Lazy import + from vllm.model_executor.layers.pooler import PoolingType + + ModelForEmbedding = _create_pooling_model_cls( + cls, + default_pooling_type=PoolingType.LAST, + default_normalize=True, + default_softmax=False, + ) + ModelForEmbedding.__name__ = \ + _get_pooling_model_name(cls.__name__, "ForEmbedding") + + return ModelForEmbedding # type: ignore + + +def as_seq_cls_model(cls: _T) -> _T: + """ + Subclass an existing vLLM model to support classify and score tasks. + + By default, the class probabilities are extracted from the softmaxed + hidden state corresponding to the last token. + + Note: + We assume that the classification head is a single linear layer + stored as the attribute `score` of the top-level model; + please implement your own model if this is not the case. + """ + # Avoid modifying existing classification models + if is_pooling_model(cls): + return cls + + # Lazy import + from vllm.model_executor.layers.linear import RowParallelLinear + from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType + from vllm.model_executor.models.interfaces import SupportsCrossEncoding + from vllm.model_executor.pooling_metadata import PoolingMetadata + from vllm.sequence import IntermediateTensors + + from .utils import maybe_prefix + + ModelForPooling = _create_pooling_model_cls( + cls, + default_pooling_type=PoolingType.LAST, + default_normalize=False, + default_softmax=True, + ) + + class ModelForSequenceClassification(ModelForPooling, + SupportsCrossEncoding): + + def __init__( + self, + *, + vllm_config: "VllmConfig", + prefix: str = "", + **kwargs: Any, + ) -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.vllm_config = vllm_config + self.task = vllm_config.model_config.task + self.pooling_type = ( + vllm_config.model_config.pooler_config.pooling_type) + + self.score = RowParallelLinear(config.hidden_size, + config.num_labels, + quant_config=quant_config, + input_is_parallel=False, + bias=False, + prefix=maybe_prefix( + prefix, "score")) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return super().forward(input_ids, positions, intermediate_tensors, + inputs_embeds) + + def pooler( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + + def get_logits(hidden_states): + if isinstance(hidden_states, list): + logits = [self.score(state)[0] for state in hidden_states] + else: + logits, _ = self.score(hidden_states) + return logits + + if self.pooling_type == PoolingType.ALL: + logits = get_logits(hidden_states) + return self._pooler(logits, pooling_metadata) + else: + hidden_states = self._pooler.extract_states( + hidden_states, pooling_metadata) + logits = get_logits(hidden_states) + pooled_data = self._pooler.head(logits, pooling_metadata) + + pooled_outputs = [ + self._pooler.build_output(data) for data in pooled_data + ] + return PoolerOutput(outputs=pooled_outputs) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + tokens = getattr(self.config, "classifier_from_token", None) + method = getattr(self.config, "method", None) + + if tokens is None and method is None: + return super().load_weights(weights) + else: + # Online convert ForCausalLM into + # ForSequenceClassification model. + return seq_cls_model_loader(self, weights) + + + ModelForSequenceClassification.__name__ = \ + _get_pooling_model_name(cls.__name__, "ForSequenceClassification") + + return ModelForSequenceClassification # type: ignore + + +def as_reward_model(cls: _T) -> _T: + """ + Subclass an existing vLLM model to support reward modeling. + + By default, we return the hidden states of each token directly. + + Note: + We assume that no extra layers are added to the original model; + please implement your own model if this is not the case. + """ + # Avoid modifying existing reward models + if is_pooling_model(cls): + return cls + + # Lazy import + from vllm.model_executor.layers.pooler import PoolingType + + ModelForReward = _create_pooling_model_cls( + cls, + default_pooling_type=PoolingType.ALL, + default_normalize=False, + default_softmax=False, + ) + + ModelForReward.__name__ = \ + _get_pooling_model_name(cls.__name__, "ForReward") + + return ModelForReward # type: ignore + + +class SequenceClassificationConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + config = vllm_config.model_config.hf_config + method = getattr(config, "method", None) + tokens = getattr(config, "classifier_from_token", None) + + if method is None: + return + + assert tokens is not None + assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported" + + if method == "from_2_way_softmax": + assert len(tokens) == 2 + config.num_labels = 1 + else: + config.num_labels = len(tokens) + + +def load_weights_using_from_2_way_softmax( + model, weights: Iterable[tuple[str, torch.Tensor]]): + # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3 + from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead) + from vllm.model_executor.models.utils import AutoWeightsLoader + + model_config = model.vllm_config.model_config + tokens = getattr(model.config, "classifier_from_token", []) + tokens = cast(list[int], tokens) + assert len(tokens) == 2 + + device = model.score.weight.device + + if model.config.tie_word_embeddings: + model.lm_head = model.model.embed_tokens + else: + model.lm_head = ParallelLMHead(model.config.vocab_size, + model.config.hidden_size, + quant_config=model.quant_config) + + loader = AutoWeightsLoader(model) + loaded_weights = loader.load_weights(weights) + + from vllm.transformers_utils.tokenizer import get_tokenizer + tokenizer = get_tokenizer(model_config.tokenizer, + revision=model_config.tokenizer_revision, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code) + + false_id = tokenizer.convert_tokens_to_ids(tokens[0]) + true_id = tokenizer.convert_tokens_to_ids(tokens[1]) + weight = model.lm_head.weight.data[true_id].to(device).to( + torch.float32) - model.lm_head.weight.data[false_id].to(device).to( + torch.float32) + model.score.weight.data.copy_(weight) + + del model.lm_head + loaded_weights.add("score.weight") + loaded_weights.discard("lm_head.weight") + return loaded_weights + + +SEQ_CLS_LOAD_METHODS = { + "from_2_way_softmax": load_weights_using_from_2_way_softmax, +} + + +def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]): + # Online convert ForCausalLM into ForSequenceClassification model. + # - from_2_way_softmax: + # - Qwen3ForCausalLM + # - Qwen3-Reranker + # - Qwen2ForCausalLM + # - mxbai-rerank-v2 + + config = model.vllm_config.model_config.hf_config + method = getattr(config, "method", None) + assert method in SEQ_CLS_LOAD_METHODS, f"method {method} not supported" + return SEQ_CLS_LOAD_METHODS[method](model, weights) diff --git a/vllm/model_executor/models/aimv2.py b/vllm/model_executor/models/aimv2.py new file mode 100644 index 0000000..b13d863 --- /dev/null +++ b/vllm/model_executor/models/aimv2.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# A modified implementation of the AIMv2 Transformer +# inserted here also the image tokenizer used by Ovis2 +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn + +from vllm.attention.layer import MultiHeadAttention +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.utils import divide +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.transformers_utils.configs.ovis import AIMv2Config + + +class AIMv2SwiGLUFFN(nn.Module): + + def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, + prefix: str): + super().__init__() + hidden_features = config.intermediate_size + in_features = config.hidden_size + bias = config.use_bias + + self.fc13 = MergedColumnParallelLinear( + in_features, + [hidden_features] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc13", + ) + self.fc2 = RowParallelLinear( + input_size=hidden_features, + output_size=in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.fc13(x) + x = self.act_fn(x) + x, _ = self.fc2(x) + return x + + +class AIMv2PatchEmbed(nn.Module): + + def __init__(self, config: AIMv2Config): + super().__init__() + self.proj = nn.Conv2d( + config.num_channels, + config.hidden_size, + kernel_size=(config.patch_size, config.patch_size), + stride=(config.patch_size, config.patch_size), + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x).flatten(2).transpose(1, 2) + x = self.norm.forward_native(x) + return x + + +class AIMv2ViTPreprocessor(nn.Module): + + def __init__(self, config: AIMv2Config): + super().__init__() + num_patches = (config.image_size // config.patch_size)**2 + + self.patchifier = AIMv2PatchEmbed(config) + self.pos_embed = nn.Parameter( + torch.zeros((1, num_patches, config.hidden_size))) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + tokens = self.patchifier(x) + _, N, _ = tokens.shape + pos_embed = self.pos_embed.to(tokens.device) + tokens = tokens + pos_embed[:, :N] + return tokens + + +class AIMv2Attention(nn.Module): + + def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, + prefix: str): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + + self.qkv = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + + self.proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.proj", + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + self.attn = MultiHeadAttention(self.num_heads_per_partition, + self.head_dim, self.scale) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + qkv, _ = self.qkv(x) + q, k, v = qkv.chunk(3, dim=-1) + + x = self.attn(q, k, v) + x, _ = self.proj(x) + return x + + +class AIMv2Block(nn.Module): + + def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, + prefix: str): + super().__init__() + self.attn = AIMv2Attention(config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = AIMv2SwiGLUFFN(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.attn(self.norm_1.forward_native(x)) + x = x + self.mlp(self.norm_2.forward_native(x)) + return x + + +class AIMv2Transformer(nn.Module): + + def __init__( + self, + config: AIMv2Config, + quant_config: QuantizationConfig, + *, + require_post_norm: Optional[bool] = None, + prefix: str = "", + ): + super().__init__() + + self.blocks = nn.ModuleList([ + AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}") + for i in range(config.num_hidden_layers) + ]) + if require_post_norm: + self.post_trunk_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + else: + self.post_trunk_norm = None + + def forward(self, tokens: torch.Tensor) -> torch.Tensor: + # they take the -1 as the ref embeddings, like a clip skip + for block in self.blocks: + tokens = block(tokens) + if self.post_trunk_norm is not None: + tokens = self.post_trunk_norm(tokens) + return tokens + + +class AIMv2Model(torch.nn.Module): + + def __init__(self, + config: AIMv2Config, + quant_config: QuantizationConfig, + *, + require_post_norm: Optional[bool] = None, + prefix: str = ""): + super().__init__() + self.preprocessor = AIMv2ViTPreprocessor(config) + self.trunk = AIMv2Transformer(config, + quant_config=quant_config, + require_post_norm=require_post_norm, + prefix=f"{prefix}.trunk") + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + + x = self.preprocessor(pixel_values) + x = self.trunk(x) + + return x + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".fc13", ".fc1", 0), + (".fc13", ".fc3", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + # post_layernorm is optional in SiglipVisionModel + if (name.startswith("trunk.post_trunk_norm") + and self.trunk.post_trunk_norm is None): + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py new file mode 100644 index 0000000..4693c94 --- /dev/null +++ b/vllm/model_executor/models/arctic.py @@ -0,0 +1,559 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Snowflake Arctic model.""" +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch import nn + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.deepspeedfp import ( + DeepSpeedFPConfig, DeepSpeedFPParameter) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.arctic import ArcticConfig + +from .interfaces import SupportsPP, SupportsQuant +from .utils import (extract_layer_index, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class ArcticMLP(nn.Module): + + def __init__(self, + config: ArcticConfig, + expert_id: int = -1, + is_residual_mlp: bool = False, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = ""): + super().__init__() + self.hidden_size = config.hidden_size + self.expert_id = expert_id + + self.ffn_dim = config.intermediate_size if not is_residual_mlp \ + else self.hidden_size + + self.w13 = MergedColumnParallelLinear(self.hidden_size, + [self.ffn_dim] * 2, + bias=False, + quant_config=quant_config) + self.w2 = RowParallelLinear(self.ffn_dim, + self.hidden_size, + bias=False, + reduce_results=reduce_results, + quant_config=quant_config) + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, hidden_states): + gate_up, _ = self.w13(hidden_states) + hidden_states = self.act_fn(gate_up) + hidden_states, _ = self.w2(hidden_states) + return hidden_states + + +class ArcticMoE(nn.Module): + """ + Model-parallel implementation of Arctic MoE Layer. + """ + + def __init__(self, + config: ArcticConfig, + tp_size: Optional[int] = None, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = ""): + super().__init__() + + layer_id = extract_layer_index(prefix) + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.hidden_size = config.hidden_size + self.num_experts = config.num_local_experts + self.layer_id = layer_id + self.top_k = config.num_experts_per_tok + self.intermediate_size = config.intermediate_size // self.tp_size + + self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0 + self.is_quant = isinstance(quant_config, DeepSpeedFPConfig) + self.reduce_results = reduce_results + # Some other parameters + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + if not self.is_moe_layer: + self.mlp = ArcticMLP(config, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.mlp") + else: + self.gate = ReplicatedLinear(self.hidden_size, + self.num_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=quant_config, + prefix=f"{prefix}.gate") + if self.is_quant: + self.ws = DeepSpeedFPParameter( + torch.Size((self.num_experts, 2 * self.intermediate_size, + self.hidden_size)), + params_dtype=params_dtype, + quant_config=quant_config, + ) + self.w2s = DeepSpeedFPParameter( + torch.Size((self.num_experts, self.hidden_size, + self.intermediate_size)), + params_dtype=params_dtype, + quant_config=quant_config, + ) + else: + self.ws = nn.Parameter( + torch.empty(self.num_experts, + 2 * self.intermediate_size, + self.hidden_size, + device=current_platform.device_type, + dtype=self.params_dtype)) + self.w2s = nn.Parameter( + torch.empty(self.num_experts, + self.hidden_size, + self.intermediate_size, + device=current_platform.device_type, + dtype=self.params_dtype)) + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.ds_dequantize() if self.is_quant else param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + if weight_name.endswith("w1.weight"): + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w3.weight"): + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w2.weight"): + param_data[expert_id, :, :] = loaded_weight[:, shard] + if self.is_quant: + param.ds_quantize_(param_data) + + def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + do_normalize = self.top_k > 1 + topk_weights, topk_ids, token_expert_indices = fused_topk( + hidden_states, router_logits, self.top_k, renormalize=do_normalize) + # topk_ids: (num_tokens, k) + if self.is_quant: + if 2 * num_tokens <= self.num_experts: + # If much fewer tokens than experts, use selective dequantize. + ws_dequantized = self.ws.ds_selective_dequantize( + topk_ids.flatten()) + w2s_dequantized = self.w2s.ds_selective_dequantize( + topk_ids.flatten()) + # We gathered the experts to the tokens so update the mapping. + topk_ids = torch.arange( + 0, + topk_ids.numel(), + device=topk_ids.device, + ).reshape(topk_ids.shape) + else: + ws_dequantized = self.ws.ds_dequantize() + w2s_dequantized = self.w2s.ds_dequantize() + + final_hidden_states = fused_experts( + hidden_states, + ws_dequantized if self.is_quant else self.ws, + w2s_dequantized if self.is_quant else self.w2s, + topk_weights, + topk_ids, + inplace=True) + if self.reduce_results and self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_size) + + def forward(self, hidden_states: torch.Tensor): + if self.is_moe_layer: + final_hidden_states = self.local_moe_fused(hidden_states) + else: + final_hidden_states = self.mlp(hidden_states) + return final_hidden_states + + +class ArcticAttention(nn.Module): + + def __init__( + self, + config: ArcticConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear(self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + reduce_results=True, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=int(self.rope_theta), + is_neox_style=True, + ) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class ArcticDecoderLayer(nn.Module): + + def __init__( + self, + config: ArcticConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + layer_idx = extract_layer_index(prefix) + is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0 + self.use_residual = config.use_residual and is_moe_layer + self.self_attn = ArcticAttention(config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn") + self.block_sparse_moe = ArcticMoE( + config, + quant_config=quant_config, + reduce_results=(not self.use_residual), + prefix=f"{prefix}.block_sparse_moe", + ) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + if self.use_residual: + self.residual_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.residual_mlp = ArcticMLP(config, + is_residual_mlp=True, + reduce_results=False, + prefix=f"{prefix}.residual_mlp") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + residual_input = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states = residual_input + hidden_states + + residual_attn = hidden_states + if self.use_residual: + hidden_states = self.residual_layernorm(hidden_states) + hidden_states = self.residual_mlp(hidden_states) + residual_mlp = hidden_states + hidden_states = self.post_attention_layernorm(residual_input) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual_mlp + hidden_states + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + hidden_states = residual_attn + hidden_states + else: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual_attn + hidden_states + return hidden_states + + +@support_torch_compile +class ArcticModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=self.vocab_size) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: ArcticDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self._attn_implementation = config._attn_implementation + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.model = ArcticModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + self.unpadded_vocab_size = config.vocab_size + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + mlp_params_mapping: list[tuple[str, str, int]] = [] + expert_params_mapping: list[tuple[str, str, int]] = [] + num_layers = self.config.num_hidden_layers + + for layer in range(num_layers): + mlp_params_mapping.append( + (f"layers.{layer}.residual_mlp.w13.weight", + f"layers.{layer}.residual_mlp.w1.weight", 0)) + mlp_params_mapping.append( + (f"layers.{layer}.residual_mlp.w13.weight", + f"layers.{layer}.residual_mlp.w3.weight", 1)) + if layer % 2 == 0: + # MLP layers + mlp_params_mapping.append( + (f"layers.{layer}.block_sparse_moe.mlp.w13.weight", + f"layers.{layer}.block_sparse_moe.mlp.w1.weight", 0)) + mlp_params_mapping.append( + (f"layers.{layer}.block_sparse_moe.mlp.w13.weight", + f"layers.{layer}.block_sparse_moe.mlp.w3.weight", 1)) + else: + # MoE layers + for expert_id in range(self.config.num_local_experts): + expert_params_mapping.append( + ("ws", f"experts.{expert_id}.w1.weight", expert_id)) + expert_params_mapping.append( + ("w2s", f"experts.{expert_id}.w2.weight", expert_id)) + expert_params_mapping.append( + ("ws", f"experts.{expert_id}.w3.weight", expert_id)) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + logger.info( + "It will take ~10 minutes loading from the 16-bit weights. " + "Alternatively, use the prequantized 8-bit weights of arctic " + "and set load-format to `sharded_state` will accelerate loading.") + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, shard_id in mlp_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, shard_id \ + in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py new file mode 100644 index 0000000..8ae1680 --- /dev/null +++ b/vllm/model_executor/models/aria.py @@ -0,0 +1,670 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable, Mapping, Sequence +from typing import Optional, TypedDict, Union + +import torch +import torch.nn as nn +from transformers import AriaConfig, AriaTextConfig, BatchFeature +from transformers.models.aria.modeling_aria import AriaCrossAttention +from transformers.models.aria.processing_aria import AriaProcessor + +from vllm.config import CacheConfig, QuantizationConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors + +# yapf: disable +from .idefics2_vision_model import Idefics2VisionConfig +from .idefics2_vision_model import ( + Idefics2VisionTransformer as Idefics3VisionTransformer) +# yapf: enable +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant +from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + is_pp_missing_parameter, maybe_prefix, + merge_multimodal_embeddings) + + +class AriaImagePixelInputs(TypedDict): + pixel_values: torch.Tensor + pixel_mask: Optional[torch.Tensor] + """ + Shape: + pixel_values: `(batch_size * num_images, num_channels, height, width)` + pixel_mask: `(batch_size * num_images, height, width)` + """ + + +class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant): + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + + def __init__( + self, + config: Idefics2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, quant_config=quant_config, prefix=prefix) + # Unlike Idefics3VisionTransformer which uses LayerNorm after the + # final layer, Aria omits this normalization, so we replace it with an + # Identity layer + self.post_layernorm = nn.Identity() + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + + # NOTE: post_layernorm is not used in Aria + if "post_layernorm" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class AriaProjectorMLP(nn.Module): + + def __init__( + self, + in_features: int, + hidden_features: int, + output_dim: int, + ) -> None: + super().__init__() + + self.linear_in = ColumnParallelLinear(in_features, + hidden_features, + bias=False) + self.linear_out = RowParallelLinear(hidden_features, + output_dim, + bias=False) + self.act = get_act_fn("gelu_new") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.linear_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.linear_out(hidden_states) + return hidden_states + + +class AriaProjector(nn.Module): + """ + A projection module with one cross attention layer and one FFN layer, which + projects ViT's outputs into MoE's inputs. + + Args: + patch_to_query_dict (dict): Maps patch numbers to their corresponding + query numbers, + e.g., {1225: 128, 4900: 256}. This allows for different query sizes + based on image resolution. + embed_dim (int): Embedding dimension. + num_heads (int): Number of attention heads. + kv_dim (int): Dimension of key and value. + ff_dim (int): Hidden dimension of the feed-forward network. + output_dim (int): Output dimension. + norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm. + + Outputs: + A tensor with the shape of (batch_size, query_number, output_dim) + """ + + def __init__(self, config: AriaConfig) -> None: + super().__init__() + + self.patch_to_query_dict = config.projector_patch_to_query_dict + self.in_features = config.vision_config.hidden_size + self.num_heads = config.vision_config.num_attention_heads + self.kv_dim = config.vision_config.hidden_size + self.hidden_features = config.text_config.hidden_size + self.output_dim = config.text_config.hidden_size + + self.query = nn.Parameter( + torch.empty(config.max_value_projector_patch_to_query_dict, + self.in_features)) + + self.cross_attn = AriaCrossAttention(config) + + self.layer_norm = nn.LayerNorm(self.in_features) + self.feed_forward = AriaProjectorMLP(self.in_features, + self.hidden_features, + self.output_dim) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, num_patches = x.shape[0], x.shape[1] + + if num_patches not in self.patch_to_query_dict: + raise KeyError(f"Number of patches {num_patches} not found in " + "patch_to_query_dict amongst possible values " + f"{self.patch_to_query_dict.keys()}.") + + query_num = self.patch_to_query_dict[num_patches] + + queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1) + + if attn_mask is not None: + attn_mask = attn_mask.repeat_interleave(self.num_heads, 0) + attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1) + + attention_out = self.cross_attn(x, queries, attn_mask=attn_mask) + + out = self.feed_forward(self.layer_norm(attention_out)) + + return out + + +class AriaFusedMoE(FusedMoE): + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + shard_id: str) -> None: + # Override the weight_loader to handle the expert weights in the Aria + # model, which are already packed with experts, and merge the gate and + # up weights for each expert. + # Note: Loading expert weights with quantization is not supported + tp_rank = get_tensor_model_parallel_rank() + if shard_id == 'w13': + # the shape of loaded_weight is + # (num_experts, hidden_size, 2 * moe_intermediate_size) + if self.tp_size > 1: + up, gate = loaded_weight.chunk(2, dim=-1) + up_current_rank = up.chunk(self.tp_size, dim=-1)[tp_rank] + gate_current_rank = gate.chunk(self.tp_size, dim=-1)[tp_rank] + up_and_gate = torch.cat([up_current_rank, gate_current_rank], + dim=-1).transpose(1, 2) + param.data.copy_(up_and_gate) + else: + param.data.copy_(loaded_weight.transpose(1, 2)) + elif shard_id == 'w2': + # the shape of loaded_weight is + # (num_experts, moe_intermediate_size, hidden_size) + if self.tp_size > 1: + down_current_rank = loaded_weight.chunk(self.tp_size, + dim=1)[tp_rank] + param.data.copy_(down_current_rank.transpose(1, 2)) + else: + param.data.copy_(loaded_weight.transpose(1, 2)) + + +class AriaTextMoELayer(nn.Module): + """ + Mixture of Experts (MoE) Layer for the AriaMoE model. + + This layer implements the MoE mechanism, which routes input tokens to + different experts based on a routing algorithm, processes them through the + experts, and then combines the outputs. + """ + + def __init__( + self, + config: AriaTextConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + + self.router_weight = nn.Parameter( + torch.empty( + (self.config.moe_num_experts, self.config.hidden_size))) + + self.experts = AriaFusedMoE( + num_experts=config.moe_num_experts, + top_k=config.moe_topk, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + reduce_results=True, + prefix=f"{prefix}.experts", + ) + self.shared_experts = LlamaMLP( + config.hidden_size, + config.intermediate_size * config.moe_num_shared_experts, + "silu", + quant_config=quant_config, + bias=config.mlp_bias, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the MoE Layer. + + Args: + hidden_states (torch.Tensor): Input tensor of shape (batch_size, + sequence_length, hidden_size). + + Returns: + torch.Tensor: Output tensor after passing through the MoE layer. + """ + + router_output = torch.nn.functional.linear(hidden_states, + self.router_weight) + + hidden_states_copy = hidden_states.clone() + # NOTE: hidden_states will be modified inplace by `FusedMoE` + sparse_expert_output = self.experts(hidden_states, router_output) + shared_expert_output = self.shared_experts(hidden_states_copy) + + return sparse_expert_output + shared_expert_output + + +class AriaTextDecoderLayer(LlamaDecoderLayer): + """ + Custom Decoder Layer for the AriaMoE model which modifies the standard + `LlamaDecoderLayer` by replacing the traditional MLP with a Mixture of + Experts (MoE) Layer. + """ + + def __init__( + self, + config: AriaTextConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, cache_config, quant_config, prefix) + self.mlp = AriaTextMoELayer(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + +class AriaTextModel(LlamaModel, SupportsQuant): + """ + Custom LlamaModel for the AriaMoE model which modifies the standard + LlamaModel by replacing the `LlamaDecoderLayer` with `MoEDecoderLayer`. + """ + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + "experts.w13_weight": ["experts.fc1.weight"], + "experts.w2_weight": ["experts.fc2.weight"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + layer_type=AriaTextDecoderLayer) + + # Adapted from LlamaModel.load_weights with the modification of adding + # the expert weights mapping to `stacked_params_mapping` + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ("experts.w13_weight", "experts.fc1.weight", 'w13'), + ("experts.w2_weight", "experts.fc2.weight", 'w2'), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class AriaProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(AriaConfig) + + def get_vision_config(self): + return self.get_hf_config().vision_config + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(AriaProcessor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_num_image_tokens(self) -> int: + hf_config = self.get_hf_config() + return max(hf_config.projector_patch_to_query_dict.values()) + + +class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token: str = processor.tokenizer.image_token # type: ignore + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + vision_config = self.info.get_vision_config() + + max_image_size = vision_config.image_size + num_images = mm_counts.get("image", 0) + + return { + "image": + self._get_dummy_images(width=max_image_size, + height=max_image_size, + num_images=num_images) + } + + +class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]): + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + pixel_mask=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_config = self.info.get_hf_config() + image_token_id = hf_config.image_token_index + + num_image_tokens = self.info.get_num_image_tokens() + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=[image_token_id] * num_image_tokens, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor, + info=AriaProcessingInfo, + dummy_inputs=AriaDummyInputsBuilder) +class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): + """ + Aria model for conditional generation tasks. + + This model combines a vision tower, a multi-modal projector, and a language + model to perform tasks that involve both image and text inputs. + """ + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # mapping for new names in checkpoint saved after transformers v4.52 + "model.language_model.": "language_model.model.", + "model.vision_tower.": "vision_tower.", + "model.multi_modal_projector.": "multi_modal_projector.", + # mapping for original checkpoint + "language_model.model": "language_model", + "language_model.lm_head": "lm_head", + }, + orig_to_new_suffix={ + "router.weight": "router_weight", + }, + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<|fim_prefix|><|img|><|fim_suffix|>" + + raise ValueError("Only image modality is supported") + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.config = config + self.vision_tower = AriaVisionTransformer( + config.vision_config, + quant_config=quant_config, + prefix=f"{prefix}.vision_tower", + ) + self.multi_modal_projector = AriaProjector(config) + self.vocab_size = config.text_config.vocab_size + self.language_model = AriaTextModel( + vllm_config=vllm_config.with_hf_config(config.text_config), + prefix=maybe_prefix(prefix, "language_model.model"), + ) + self.pad_token_id = (self.config.pad_token_id + if self.config.pad_token_id is not None else -1) + self.unpadded_vocab_size = config.text_config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.text_config.hidden_size, + org_num_embeddings=self.language_model.org_vocab_size, + quant_config=quant_config, + ) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.vocab_size, logit_scale) + + def _validate_image_sizes( + self, images: list[torch.Tensor]) -> list[torch.Tensor]: + if not all(img.shape == images[0].shape for img in images): + raise ValueError("All images must be the same size") + return images + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[AriaImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + pixel_mask = kwargs.pop("pixel_mask", None) + + if pixel_values is None: + return None + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + pixel_values = self._validate_image_sizes(pixel_values) + pixel_values = flatten_bn(pixel_values, concat=True) + + if pixel_mask is not None: + if not isinstance(pixel_mask, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel mask. " + f"Got type: {type(pixel_mask)}") + + pixel_mask = flatten_bn(pixel_mask, concat=True) + + return AriaImagePixelInputs( + pixel_values=pixel_values, + pixel_mask=pixel_mask, + ) + + def _create_patch_attention_mask( + self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor: + if pixel_mask is None: + return None + + patches_subgrid = pixel_mask.unfold( + dimension=1, + size=self.vision_tower.config.patch_size, + step=self.vision_tower.config.patch_size, + ).unfold( + dimension=2, + size=self.vision_tower.config.patch_size, + step=self.vision_tower.config.patch_size, + ) + return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + def _process_image_input( + self, image_input: AriaImagePixelInputs + ) -> tuple[torch.Tensor, torch.Tensor]: + assert self.vision_tower is not None + + pixel_values = image_input['pixel_values'] + pixel_mask = image_input['pixel_mask'] + + patch_attention_mask = self._create_patch_attention_mask(pixel_mask) + + image_outputs = self.vision_tower( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + image_attn_mask = None + if patch_attention_mask is not None: + flattened_mask = patch_attention_mask.flatten(1) + image_attn_mask = torch.logical_not(flattened_mask) + + return self.multi_modal_projector(image_outputs, image_attn_mask) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + multimodal_embeddings = self._process_image_input(image_input) + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None \ + and len(multimodal_embeddings) != 0: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.image_token_index) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if inputs_embeds is None: + multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + inputs_embeds = self.get_input_embeddings(input_ids, + multimodal_embeddings) + input_ids = None + + hidden_states = self.language_model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py new file mode 100644 index 0000000..45dd660 --- /dev/null +++ b/vllm/model_executor/models/aya_vision.py @@ -0,0 +1,486 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision +from collections.abc import Iterable, Mapping, Sequence +from typing import Literal, Optional, TypedDict, Union, cast + +import torch +from torch import nn +from transformers import BatchFeature, GotOcr2ImageProcessor +from transformers.activations import ACT2FN +from transformers.image_processing_utils import get_size_dict +from transformers.models.aya_vision import AyaVisionConfig +from transformers.models.aya_vision.processing_aya_vision import ( + AyaVisionProcessor) +from transformers.models.got_ocr2.image_processing_got_ocr2 import ( + get_optimal_tiled_canvas) + +from vllm.config import VllmConfig +from vllm.jsontree import json_map_leaves +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs +from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, + MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalFieldConfig, + PromptReplacement, PromptUpdate, + PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .siglip import SiglipVisionModel +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) + + +class AyaVisionImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: torch.Tensor + """ + Shape: `(num_patches_total, num_channels, height, width)` + + `num_patches_total` is the total number of patches over each image over each + prompt in the batch. + """ + + num_patches: torch.Tensor + """Shape: `(batch_size * num_images)`""" + + +class AyaVisionMultiModalProjector(nn.Module): + + def __init__(self, config: AyaVisionConfig): + super().__init__() + self.config = config + self.downsample_factor = config.downsample_factor + self.alignment_intermediate_size = getattr( + config, "alignment_intermediate_size", + config.text_config.hidden_size) + self.layernorm = nn.LayerNorm(config.vision_config.hidden_size * + (config.downsample_factor**2), + eps=config.adapter_layer_norm_eps) + + self.linear_1 = nn.Linear( + config.vision_config.hidden_size * (config.downsample_factor**2), + self.alignment_intermediate_size, + bias=True, + ) + + self.act = ACT2FN["silu"] # SwiGLU uses SiLU activation + # For SwiGLU, project down to half size since we split intermediate dim + self.linear_2 = nn.Linear(self.alignment_intermediate_size // 2, + config.text_config.hidden_size, + bias=True) + + def forward(self, image_features: torch.Tensor) -> torch.Tensor: + image_features = self.pixel_shuffle(image_features) + image_features = self.layernorm(image_features) + hidden_states = self.linear_1(image_features) + + # Split along last dimension and apply SwiGLU + x, gate = hidden_states.chunk(2, dim=-1) + hidden_states = self.act(gate) * x + + hidden_states = self.linear_2(hidden_states) + return hidden_states + + def pixel_shuffle(self, + image_features: torch.Tensor) -> torch.Tensor: # B, S, D + batch_size, seq_length, _ = image_features.shape + height = width = int(seq_length**0.5) + image_features = image_features.reshape(image_features.shape[0], width, + height, -1) + channels = image_features.shape[-1] + image_features = image_features.reshape( + batch_size, width, int(height / self.downsample_factor), + int(channels * self.downsample_factor)) + image_features = image_features.permute(0, 2, 1, 3) + image_features = image_features.reshape( + batch_size, int(height / self.downsample_factor), + int(width / self.downsample_factor), -1) + image_features = image_features.permute(0, 2, 1, 3) + return image_features + + +class AyaVisionProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self) -> AyaVisionConfig: + return self.ctx.get_hf_config(AyaVisionConfig) + + def get_hf_processor(self, **kwargs: object) -> AyaVisionProcessor: + processor = self.ctx.get_hf_processor(AyaVisionProcessor, **kwargs) + + # Temporary workaround since this processor has multiple image tokens + # See https://github.com/huggingface/transformers/issues/38350 + processor._check_special_mm_tokens = lambda *args, **kwargs: None + + return processor + + def get_image_processor(self) -> GotOcr2ImageProcessor: + return self.get_hf_processor().image_processor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_image_size_with_most_features(self) -> ImageSize: + image_processor = self.get_image_processor() + height = image_processor.size['height'] + width = image_processor.size['width'] + max_patches = image_processor.max_patches + return ImageSize(height=height * max_patches, + width=width * max_patches) + + def get_num_patches(self, *, image_width: int, image_height: int, + size: dict, min_patches: int, max_patches: int) -> int: + """ + Calculate the number of patches needed for a given image based on size + constraints. This method replicates and adjusts the logic from: + transformers/models/got_ocr2/image_processing_got_ocr2 + """ + size = get_size_dict(size, default_to_square=False) + num_columns, num_rows = get_optimal_tiled_canvas( + (image_height, image_width), (size["height"], size["width"]), + min_patches, max_patches) + num_blocks = num_columns * num_rows + return num_blocks if num_blocks == 1 else num_blocks + 1 + + +class AyaVisionDummyInputsBuilder( + BaseDummyInputsBuilder[AyaVisionProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + image_size = \ + self.info.get_image_size_with_most_features() + + return { + "image": + self._get_dummy_images(width=image_size.width, + height=image_size.height, + num_images=num_images) + } + + +class AyaVisionMultiModalProcessor( + BaseMultiModalProcessor[AyaVisionProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt, + mm_data, + mm_kwargs, + tok_kwargs, + ) + hf_processor = self.info.get_hf_processor(**mm_kwargs) + image_processor = hf_processor.image_processor + + # HF processor pops the `num_patches` kwarg, which is needed by vLLM + if (images := mm_data.get("images")) is not None: + parsed_images = (self._get_data_parser().parse_mm_data({ + "image": + images + }).get_items("image", ImageProcessorItems)) + image_sizes = [ + parsed_images.get_image_size(i) + for i in range(len(parsed_images)) + ] + + num_patches = [ + self.info.get_num_patches( + image_width=image_size.width, + image_height=image_size.height, + size=image_processor.size, + min_patches=image_processor.min_patches, + max_patches=image_processor.max_patches) + for image_size in image_sizes + ] + processed_outputs["num_patches"] = torch.tensor(num_patches) + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + num_patches = hf_inputs.get("num_patches", torch.empty(0)) + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", num_patches), + num_patches=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_token = hf_processor.image_token + img_patch_token = hf_processor.img_patch_token + image_processor = hf_processor.image_processor + + def get_replacement(item_idx: int): + images: ImageProcessorItems = mm_items.get("image", + ImageProcessorItems) + image_size: ImageSize = images.get_image_size(item_idx) + num_patches = self.info.get_num_patches( + image_width=image_size.width, + image_height=image_size.height, + size=image_processor.size, + min_patches=image_processor.min_patches, + max_patches=image_processor.max_patches, + ) + repl = hf_processor._prompt_split_image(num_patches=num_patches) + + return PromptUpdateDetails.select_text(repl, img_patch_token) + + return [ + PromptReplacement( + modality="image", + target=image_token, + replacement=get_replacement, + ) + ] + + +def _get_num_hidden_layers(hf_config: AyaVisionConfig) -> int: + feature_layers = hf_config.vision_feature_layer + num_hidden_layers = hf_config.vision_config.num_hidden_layers + # If we have one feature layer, initialize up to that layer + if isinstance(feature_layers, int): + return _get_layer_index(feature_layers, num_hidden_layers) + # If we have multiple feature layers, initialize up to the deepest m + elif isinstance(feature_layers, (list, tuple)): + return max( + _get_layer_index(idx, num_hidden_layers) for idx in feature_layers) + raise TypeError(f"vision_layer_feature type: {type(feature_layers)}" + " is not supported") + + +def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: + if feature_layer_index < 0: + return num_hidden_layers + feature_layer_index + 1 + return feature_layer_index + + +@MULTIMODAL_REGISTRY.register_processor( + AyaVisionMultiModalProcessor, + info=AyaVisionProcessingInfo, + dummy_inputs=AyaVisionDummyInputsBuilder) +class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # mapping for new names in checkpoint saved after transformers v4.52 + "model.language_model.": "language_model.model.", + "model.vision_tower.": "vision_tower.", + "model.multi_modal_projector.": "multi_modal_projector.", + "lm_head.": "language_model.lm_head.", + }) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "" + + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: AyaVisionConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + num_hidden_layers = _get_num_hidden_layers(config) + self.config = config + self.quant_config = quant_config + self.multimodal_config = multimodal_config + + self.vision_tower = SiglipVisionModel( + config.vision_config, + quant_config, + num_hidden_layers_override=num_hidden_layers, + prefix=maybe_prefix(prefix, "vision_model")) + self.vocab_size = config.text_config.vocab_size + self.multi_modal_projector = AyaVisionMultiModalProjector(config) + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "model"), + # Cohere2ForCausalLM and CohereForCausalLM are the same on vllm + architectures=["Cohere2ForCausalLM"]) + + @property + def dtype(self): + return next(self.parameters()).dtype + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def _image_pixels_to_features(self, vision_tower: SiglipVisionModel, + pixel_values: torch.Tensor, + **kwargs) -> torch.Tensor: + target_dtype = vision_tower.get_input_embeddings().weight.dtype + image_features = vision_tower(pixel_values.to(dtype=target_dtype), + **kwargs) + + def select_features(leaf: torch.Tensor): + return self._select_image_features( + leaf, + strategy=self.config.vision_feature_select_strategy, + ) + + return cast( + Union[torch.Tensor, tuple[torch.Tensor, ...]], + json_map_leaves(select_features, image_features), + ) + + def _select_image_features(self, image_features: torch.Tensor, *, + strategy: str) -> torch.Tensor: + if strategy == "default": + return image_features[:, 1:] + elif strategy == "full": + return image_features + + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + def _process_image_input(self, image_input: AyaVisionImagePixelInputs, + **kwargs) -> list[torch.Tensor]: + assert self.vision_tower is not None + pixel_values = image_input["pixel_values"] + num_patches = image_input["num_patches"] + image_features = self._image_pixels_to_features( + self.vision_tower, pixel_values=pixel_values) + image_embeds = self.multi_modal_projector(image_features) + return [ + e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist()) + ] + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + if d.shape != expected_dims: + raise ValueError( + "The expected shape of pixel values per image per batch " + f"is {expected_dims}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + num_patches = kwargs.pop("num_patches", None) + image_embeds = kwargs.pop("image_embeds", None) + assert image_embeds is None, "Aya Vision does not support image_embeds." + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + if num_patches is not None and not isinstance(num_patches, + (torch.Tensor, list)): + raise ValueError("Incorrect type of num_patches. " + f"Got type: {type(num_patches)}") + + pixel_values = flatten_bn(pixel_values, concat=True) + num_patches = flatten_bn(num_patches, concat=True) + + return AyaVisionImagePixelInputs( + type="pixel_values", + pixel_values=self._validate_pixel_values(pixel_values), + num_patches=num_patches, + ) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + + return self._process_image_input(image_input, **kwargs) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None \ + and len(multimodal_embeddings) != 0: + inputs_embeds = merge_multimodal_embeddings( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + placeholder_token_id=self.config.image_token_index, + ) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py new file mode 100644 index 0000000..ea41222 --- /dev/null +++ b/vllm/model_executor/models/baichuan.py @@ -0,0 +1,583 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only BaiChuan model compatible with HuggingFace weights.""" +import math +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +import os +import re +import vllm.envs as envs + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, row_parallel_weight_loader) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + +from vllm import _custom_ops as ops +from vllm.model_executor.utils import pad_weight, gemm_bank_conf + + + +def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32, + ) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(start=1, + end=1 + 2 * num_remaining_heads, + step=2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + +class BaiChuanMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class BaiChuanAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + position_embedding: str, + rope_theta: float = 10000, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = hidden_size + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( + ) + self.total_num_heads = num_heads + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + self.head_dim = hidden_size // self.total_num_heads + self.position_embedding = position_embedding + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + # pylint: disable=invalid-name + self.W_pack = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + # Create the alibi slopes and slice them. + if self.position_embedding == "ALIBI": + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(self.total_num_heads) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + + scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + quant_config=quant_config, + prefix=f"{prefix}.attn") + else: + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + self.quant_method = None + if quant_config is not None: + self.quant_method=quant_config.get_name() + self.quant_config=quant_config + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.W_pack(hidden_states) + # if os.environ.get('FA_PAD') == '1' and self.quant_method is None: + # qkv = qkv[...,:-32] + q, k, v = qkv.chunk(chunks=3, dim=-1) + if self.position_embedding != "ALIBI": + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class BaiChuanDecoderLayer(nn.Module): + + def __init__(self, + config: PretrainedConfig, + position_embedding: str, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = BaiChuanAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + position_embedding=position_embedding, + rope_theta=rope_theta, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = BaiChuanMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class BaiChuanModel(nn.Module): + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + position_embedding: str = "ROPE", + ) -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: BaiChuanDecoderLayer(config, + position_embedding, + cache_config, + quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + self.quant_method = None + if quant_config is not None: + self.quant_method=quant_config.get_name() + self.quant_config=quant_config + + self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' + self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' + self.use_fa_pad = os.environ.get('FA_PAD') == '1' + self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual, + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + if self.use_llama_nn and self.quant_method is None : + lay_key_words = [ + "self_attn.W_pack.weight", + "self_attn.o_proj.weight", + "mlp.gate_up_proj.weight", + "mlp.down_proj.weight", + "lm_head.weight" + ] + combined_words = "|".join(lay_key_words) + + # lay_qkv_words = ["self_attn.W_pack.weight"] + # qkv_words = "|".join(lay_qkv_words) + + for layername in loaded_params: + weight = params_dict[layername] + if "lm_head.weight" in layername and weight.shape[1] >= 4096: + lay_key_words.append("lm_head.weight") + combined_words = "|".join(lay_key_words) + os.environ['LM_NN'] = '1' + else: + os.environ['LM_NN'] = '0' + matches = re.findall(combined_words, layername) + if matches: + # if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]): + # weight.data = pad_weight(weight.data, 32) + + # if self.use_fa_pad and (re.findall(qkv_words, layername)): + # if not gemm_bank_conf(weight.data.shape[0]): + # weight.data = pad_weight(weight.data, 32) + + _weight = torch.zeros_like(weight.data) + ori_shape =_weight.shape + + ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1]) + weight.data.copy_(_weight) + + weight.data=weight.data.reshape(ori_shape[1], -1) + else: + os.environ['LM_NN'] = '0' + os.environ['LLAMA_NN'] = '0' + + # if self.quant_method == "awq" and not envs.VLLM_USE_TRITON_AWQ: + # lay_key_words = [ + # "self_attn.W_pack.qweight", + # "self_attn.o_proj.qweight", + # "mlp.gate_up_proj.qweight", + # "mlp.down_proj.qweight" + # ] + # combined_words = "|".join(lay_key_words) + + # for layername in loaded_params: + # weight = params_dict[layername] + + # matches = re.findall(combined_words, layername) + # if matches: + # qweight =params_dict[layername] + # qzeros=params_dict[layername.replace("qweight", "qzeros")] + # scales=params_dict[layername.replace("qweight", "scales")] + # zeros_and_scalse =params_dict[layername.replace("qweight", "zeros_and_scales")] + + # group_size= self.quant_config.group_size + + # dim_n = scales.data.shape[1] + # dim_k = qweight.data.shape[0] + # pad_group=2 + + # _qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size)) + + # sz = ops.sz_permute(_sz).reshape(-1,dim_n) + + # zeros_and_scalse.data.copy_(sz) + # qweight.data.copy_(_qw) + + # #reshape + # zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size] + # qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8] + + # if dim_k % 4096==0 and self.use_awq_pad: + # zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda() + # zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous() + # qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda() + # qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous() + return loaded_params + + +class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, + SupportsQuant): + packed_modules_mapping = { + "W_pack": ["W_pack"], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + position_embedding: str = "ROPE", + ): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + self.lora_config = lora_config + self.tp_size = get_tensor_model_parallel_world_size() + self.quant_config = quant_config + self.model = BaiChuanModel(vllm_config=vllm_config, + prefix=prefix, + position_embedding=position_embedding) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + self.lm_head.weight.weight_loader = self.lm_head_weight_loader + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def lm_head_weight_loader(self, param: nn.Parameter, + loaded_weight: torch.Tensor): + # Unlike Baichuan, Baichuan2 normalizes the head weights. + # Refer to: + # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508 + # Distinguish between Baichuan and Baichuan2 by checking the + # vocab size. This is suggested by + # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704 + is_baichuan2 = self.config.vocab_size == 125696 + if is_baichuan2: + loaded_weight = torch.nn.functional.normalize(loaded_weight) + if self.tp_size > 1: + row_parallel_weight_loader(param, loaded_weight) + else: + default_weight_loader(param, loaded_weight) + + +class BaichuanForCausalLM(BaiChuanBaseForCausalLM): + """Baichuan 13B and Baichuan2 7B/13B. + NOTE: the class name has a lower case 'c'. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + if config.hidden_size == 4096: # baichuan2 7b + super().__init__(vllm_config=vllm_config, + prefix=prefix, + position_embedding="ROPE") + else: # baichuan 13b, baichuan2 13b + super().__init__(vllm_config=vllm_config, + prefix=prefix, + position_embedding="ALIBI") + + +class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): + """Baichuan 7B. + NOTE: the class name has an upper case 'C'. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + position_embedding="ROPE") \ No newline at end of file diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py new file mode 100644 index 0000000..d743c52 --- /dev/null +++ b/vllm/model_executor/models/bamba.py @@ -0,0 +1,558 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Bamba model.""" +# Added by the IBM Team, 2024 +from collections.abc import Iterable +from typing import Optional + +import torch +from torch import nn +from transformers import BambaConfig + +from vllm import envs +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba2_metadata import ( + Mamba2Metadata, prepare_mamba2_metadata) +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import LayerBlockType + +from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, + SupportsQuant) +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class BambaMLP(nn.Module): + + def __init__( + self, + config: BambaConfig, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=config.hidden_size, + output_sizes=[config.intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + input_size=config.intermediate_size, + output_size=config.hidden_size, + bias=bias, + quant_config=quant_config, + ) + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class BambaMixerDecoderLayer(nn.Module): + + def __init__(self, + config: BambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.mamba = MambaMixer2(hidden_size= config.hidden_size, + ssm_state_size = config.mamba_d_state, + conv_kernel_size = config.mamba_d_conv, + intermediate_size = config.mamba_expand *\ + config.hidden_size, + use_conv_bias = config.mamba_conv_bias, + use_bias = config.mamba_proj_bias, + n_groups=config.mamba_n_groups, + num_heads=config.mamba_n_heads, + head_dim=config.mamba_d_head, + rms_norm_eps=config.rms_norm_eps, + activation=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mixer", + chunk_size=config.mamba_chunk_size) + + self.feed_forward = BambaMLP(config, quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.mamba(hidden_states, mamba_cache_params, + mamba2_metadata) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +class BambaAttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: BambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if hasattr(config, "partial_rotary_factor"): + rotary_dim = self.head_dim * config.partial_rotary_factor + elif hasattr(config, "attn_rotary_emb"): + rotary_dim = config.attn_rotary_emb # for backward compatibility + else: + rotary_dim = self.head_dim # default + + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=rotary_dim, + max_position=max_position_embeddings, + rope_scaling=rope_scaling, + base=rope_theta, + is_neox_style=True, + dtype=torch.get_default_dtype(), # see impl of get_rope + ) + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + + self.feed_forward = BambaMLP(config, quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def self_attention( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attention( + positions=positions, + hidden_states=hidden_states, + ) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +ALL_DECODER_LAYER_TYPES = { + "attention": BambaAttentionDecoderLayer, + "mamba": BambaMixerDecoderLayer +} + + +class BambaModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config: BambaConfig = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + def get_layer(prefix: str): + layer_idx = int(prefix.rsplit(".", 1)[1]) + layer_class = ALL_DECODER_LAYER_TYPES[ + config.layers_block_type[layer_idx]] + return layer_class( + config, + layer_idx, + cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + self.final_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + attn_metadata = get_forward_context().attn_metadata + + if not envs.VLLM_USE_V1: + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.mamba_chunk_size, + attn_metadata=attn_metadata, + ) + else: + # v1 get mamba2_metadata from forward_context + mamba2_metadata = None + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + residual = None + num_attn = 0 + for i in range(len(self.layers)): + layer = self.layers[i] + if isinstance(layer, BambaAttentionDecoderLayer): + num_attn += 1 + + layer_mamba_cache_params = None + if isinstance(layer, + BambaMixerDecoderLayer) and mamba_cache_params: + layer_mamba_cache_params = mamba_cache_params.at_layer_idx( + i - num_attn) + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + mamba_cache_params=layer_mamba_cache_params, + mamba2_metadata=mamba2_metadata, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.final_layernorm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if "A_log" in name: + name = name.replace("A_log", "A") + + if ".self_attn." in name: + name = name.replace(".self_attn", "") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + IsHybrid, SupportsQuant): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["up_proj", "down_proj"] + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert not cache_config.enable_prefix_caching, \ + "Bamba currently does not support prefix caching" + + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.model = BambaModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + # Used to track and store by the Mamba cache between steps. + self.mamba_cache: Optional[MambaCacheManager] = None + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs): + + mamba_cache_params = None + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + num_mamba_layers = \ + self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, + LayerBlockType.mamba + ) + + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.lm_head.weight.dtype, + num_mamba_layers, *self._get_mamba_cache_shape()) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + + hidden_states = self.model(input_ids, positions, mamba_cache_params, + intermediate_tensors, inputs_embeds) + + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> tuple[tuple[int, int], tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + + conv_state_shape, temporal_state_shape = None, None + + intermediate_size = self.config.mamba_expand * hidden_size + + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards( + self.config.mamba_n_groups, world_size)) + + # - heads and n_groups are TP-ed + conv_dim = (intermediate_size + + 2 * n_groups * self.config.mamba_d_state) + conv_state_shape = ( + divide(conv_dim, world_size), + self.config.mamba_d_conv - 1, + ) + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, d_head, d_state) = (128, 64, 128) + temporal_state_shape = ( + divide(self.config.mamba_n_heads, world_size), + self.config.mamba_d_head, + self.config.mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py new file mode 100644 index 0000000..a0ec126 --- /dev/null +++ b/vllm/model_executor/models/bart.py @@ -0,0 +1,938 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Derived from BART implementation posted on HuggingFace; license below: +# +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BART model.""" +import math +from collections.abc import Iterable +from typing import Optional + +import torch +from torch import nn +from transformers import BartConfig +from transformers.utils import logging + +from vllm.attention import Attention, AttentionType +from vllm.config import CacheConfig, LoRAConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVCrossParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsQuant, SupportsV0Only +from .utils import maybe_prefix + +logger = logging.get_logger(__name__) + + +def get_bsz_seq_len(input_ids): + shp = input_ids.shape + ndim = len(shp) + if ndim == 1: + return 1, input_ids.numel() + else: + return shp[:2] + + +class BartLearnedPositionalEmbedding(VocabParallelEmbedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # Bart is set up so that if padding_idx is + # specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. + # Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + positions: torch.Tensor, + ) -> torch.Tensor: + """`input_ids' shape is expected to be [bsz x seqlen].""" + return super().forward(positions + self.offset) + + +class BartScaledWordEmbedding(VocabParallelEmbedding): + """ + This module overrides VocabParallelEmbedding's + forward by multiplying with embeddings scale. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + embed_scale: float = 1.0): + super().__init__(num_embeddings, embedding_dim) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + return super().forward(input_ids) * self.embed_scale + + +class BartParallelLMHead(ParallelLMHead): + """ + This module overrides ParallelLMHead's + forward by dividing by embeddings scale, + yielding effectively the inverse of + BartScaledWordEmbedding + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + embed_scale: float = 1.0): + super().__init__(num_embeddings, embedding_dim) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + return super().forward(input_ids) / self.embed_scale + + +class BartEncoderAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[BartConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.d_model = config.d_model + self.embed_dim = embed_dim + self.total_num_heads = num_heads + self.total_num_kv_heads = self.total_num_heads + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + self.d_model, + self.d_model // self.total_num_heads, + self.total_num_heads, + self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + ) + + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + bias=bias, + quant_config=quant_config, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = self.num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=AttentionType.ENCODER) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + attn_output = self.attn(q, k, v) + + output, _ = self.out_proj(attn_output) + return output + + +class BartDecoderSelfAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[BartConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.d_model = config.d_model + self.embed_dim = embed_dim + self.total_num_heads = num_heads + self.total_num_kv_heads = self.total_num_heads + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + self.d_model, + self.d_model // self.total_num_heads, + self.total_num_heads, + self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + ) + + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + bias=bias, + quant_config=quant_config, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = self.num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=AttentionType.DECODER) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + attn_output = self.attn(q, k, v) + + output, _ = self.out_proj(attn_output) + return output + + +class BartCrossAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[BartConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.d_model = config.d_model + self.embed_dim = embed_dim + self.total_num_heads = num_heads + self.total_num_kv_heads = self.total_num_heads + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") + self.scaling = self.head_dim**-0.5 + + # TP sharding sizes is accounted for within "*Parallel" layers. + self.qkv_proj = QKVCrossParallelLinear(self.d_model, + self.d_model // + self.total_num_heads, + self.total_num_heads, + self.total_num_kv_heads, + bias, + quant_config=quant_config) + + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + bias=bias, + quant_config=quant_config, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = self.num_heads # No GQA in bart + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=AttentionType.ENCODER_DECODER) + + def forward( + self, + decoder_hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + + q, k, v = self.qkv_proj(decoder_hidden_states, encoder_hidden_states) + + attn_output = self.attn(q, k, v) + + output, _ = self.out_proj(attn_output) + return output + + +class BartEncoderLayer(nn.Module): + + def __init__( + self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BartEncoderAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.activation_fn = get_act_fn(config.activation_function) + + ffn_hidden_size = self.embed_dim + ffn_intermediate_size = config.encoder_ffn_dim + ffn_has_bias = True + self.fc1 = ColumnParallelLinear( + ffn_hidden_size, + ffn_intermediate_size, + bias=ffn_has_bias, + quant_config=quant_config, + ) + self.act = get_act_fn("gelu") + self.fc2 = RowParallelLinear( + ffn_intermediate_size, + ffn_hidden_size, + bias=ffn_has_bias, + quant_config=quant_config, + ) + + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + r""" + Args: + hidden_states + torch.Tensor of *encoder* input embeddings. + Returns: + Encoder layer output torch.Tensor + """ + residual = hidden_states + hidden_states = self.self_attn(hidden_states=hidden_states) + + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + fc1_out, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(fc1_out) + + hidden_states, _ = self.fc2(hidden_states) + + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() + or torch.isnan(hidden_states).any()): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, + max=clamp_value) + + return hidden_states + + +class BartDecoderLayer(nn.Module): + + def __init__( + self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BartDecoderSelfAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.activation_fn = get_act_fn(config.activation_function) + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + ''' + afeldman-nm: personally I would call this "cross-attention", + however I left the name as "encoder_attn" to maintain consistency + with the name of the pretrained weights. + ''' + self.encoder_attn = BartCrossAttention( + self.embed_dim, + config.decoder_attention_heads, + config=config, + prefix=f"{prefix}.encoder_attn", + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + ffn_hidden_size = self.embed_dim + ffn_intermediate_size = config.encoder_ffn_dim + ffn_has_bias = True + self.fc1 = ColumnParallelLinear( + ffn_hidden_size, + ffn_intermediate_size, + bias=ffn_has_bias, + quant_config=quant_config, + ) + self.fc2 = RowParallelLinear( + ffn_intermediate_size, + ffn_hidden_size, + bias=ffn_has_bias, + quant_config=quant_config, + ) + + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + decoder_hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Args: + decoder_hidden_states + torch.Tensor of *decoder* input embeddings. + encoder_hidden_states + torch.Tensor of *encoder* input embeddings. + Returns: + Decoder layer output torch.Tensor + """ + residual = decoder_hidden_states + + # Self Attention + hidden_states = self.self_attn(hidden_states=decoder_hidden_states) + + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + + residual = hidden_states + + hidden_states = self.encoder_attn( + decoder_hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + fc1_out, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(fc1_out) + + hidden_states, _ = self.fc2(hidden_states) + + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + return hidden_states + + +class BartEncoder(nn.Module): + """ + Transformer encoder consisting of *config.encoder_layers* + self attention layers. Each layer is a [`BartEncoderLayer`]. + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + embed_tokens: Optional[nn.Embedding] = None, + prefix: str = ""): + super().__init__() + + self.cache_config = cache_config + self.quant_config = quant_config + self.lora_config = lora_config + embed_dim = config.d_model + self.max_source_positions = config.max_position_embeddings + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, + embed_dim, + embed_scale=embed_scale) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + ) + self.layers = nn.ModuleList([ + BartEncoderLayer(config, + cache_config, + quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(config.encoder_layers) + ]) + + self.layernorm_embedding = nn.LayerNorm(embed_dim) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Args: + input_ids + Indices of *encoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. + positions + Positions of *encoder* input sequence tokens. + Returns: + Decoder output torch.Tensor + """ + # retrieve input_ids and inputs_embeds + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + embed_pos = self.embed_positions(positions) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states=hidden_states) + + return hidden_states + + +class BartDecoder(nn.Module): + """ + Transformer decoder consisting of *config.decoder_layers* layers. + Each layer is a [`BartDecoderLayer`] + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__( + self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + embed_tokens: Optional[nn.Embedding] = None, + prefix: str = "", + ): + super().__init__() + self.cache_config = cache_config + self.quant_config = quant_config + self.lora_config = lora_config + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, + config.d_model, + embed_scale=embed_scale) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + + self.layers = nn.ModuleList( + [BartDecoderLayer(config,cache_config,quant_config, + prefix=f"{prefix}.layers.{layer_idx}") \ + for layer_idx in range(config.decoder_layers)]) + + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + def forward( + self, + decoder_input_ids: torch.Tensor, + decoder_positions: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Args: + decoder_input_ids + Indices of *decoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. + decoder_positions + Positions of *decoder* input sequence tokens. + encoder_hidden_states: + Tensor of encoder output embeddings + Returns: + Decoder output torch.Tensor + """ + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(decoder_input_ids) + else: + decoder_positions = inputs_embeds[:, -1] + + # embed positions + embed_pos = self.embed_positions(decoder_positions) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + + # decoder layers + + for decoder_layer in self.layers: + hidden_states = decoder_layer( + decoder_hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + + return hidden_states + + +class BartModel(nn.Module, SupportsQuant): + _tied_weights_keys = [ + "encoder.embed_tokens.weight", "decoder.embed_tokens.weight" + ] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.encoder = BartEncoder(config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.encoder") + self.decoder = BartDecoder(config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.decoder") + + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor) -> torch.Tensor: + r""" + Args: + input_ids + Indices of *decoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. + positions + Positions of *decoder* input sequence tokens. + encoder_input_ids + Indices of *encoder* input sequence tokens in the vocabulary. + encoder_positions: + Positions of *encoder* input sequence tokens. + Returns: + Model output torch.Tensor + """ + + encoder_hidden_states = None + + if encoder_input_ids.numel() > 0: + # Run encoder attention if a non-zero number of encoder tokens + # are provided as input + encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, + positions=encoder_positions) + + # decoder outputs consists of + # (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + decoder_input_ids=input_ids, + decoder_positions=positions, + encoder_hidden_states=encoder_hidden_states) + + return decoder_outputs + + +class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + base_model_prefix = "model" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + + super().__init__() + config = vllm_config.model_config.hf_config + lora_config = vllm_config.lora_config + # currently all existing BART models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings + self.config = config + self.model = BartModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.lm_head = BartParallelLMHead(config.vocab_size, + config.d_model, + embed_scale=embed_scale) + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + *, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + encoder_input_ids + torch.Tensor of *encoder* input token ids. + encoder_positions + torch.Tensor of *encoder* position indices + Returns: + Output torch.Tensor + """ + return self.model(input_ids, positions, encoder_input_ids, + encoder_positions) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + stacked_params_mapping = { + "q_proj": { + "param_name": "qkv_proj", + "shard_id": "q", + }, + "k_proj": { + "param_name": "qkv_proj", + "shard_id": "k", + }, + "v_proj": { + "param_name": "qkv_proj", + "shard_id": "v", + }, + } + + params_mapping = { + "beta": "bias", + "gamma": "weight", + "LayerNorm": "layernorm", + } + + def _rename_key(self, key: str): + prefix = f"{self.base_model_prefix}." + key = key[len(prefix):] if key.startswith(prefix) else key + + for src, dst in self.params_mapping.items(): + key = key.replace(src, dst) + + return key + + def _rename_stacked_param( + self, + name: str, + ) -> tuple[str, Optional[str]]: + for key, mapping in self.stacked_params_mapping.items(): + if key in name: + name = name.replace(key, mapping["param_name"]) + return name, mapping["shard_id"] + return name, None + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + + model_params_dict = dict(self.model.named_parameters()) + top_params_dict = dict(self.named_parameters()) + + weights_tuple_list = list(weights) + + shared_embedding_weight = None + shared_embedding_shard_id = None + + for name, loaded_weight in weights_tuple_list: + + name = self._rename_key(name) + name, shard_id = self._rename_stacked_param(name) + + if ('shared.weight' in name + or 'encoder.embed_tokens.weight' in name + or 'decoder.embed_tokens.weight' in name + or 'lm_head.weight' in name): + assert shared_embedding_weight is None, ( + "Conflicting embedding weights.") + shared_embedding_weight = loaded_weight + shared_embedding_shard_id = shard_id + else: + # Skip the specific downstream task weight. + if name.startswith('cls.'): + continue + # use Pooler instead. + if name.startswith('pooler.'): + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in model_params_dict: + continue + + param = model_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if shard_id: + weight_loader(param, loaded_weight, shard_id) + else: + weight_loader(param, loaded_weight) + + # Assign shared weight values + encoder_in_param = model_params_dict['encoder.embed_tokens.weight'] + encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader", + default_weight_loader) + + decoder_in_param = model_params_dict['decoder.embed_tokens.weight'] + decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader", + default_weight_loader) + + lm_head_in_param = top_params_dict['lm_head.weight'] + lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader", + default_weight_loader) + + assert shared_embedding_weight is not None + + if shared_embedding_shard_id: + encoder_in_weight_loader(encoder_in_param, shared_embedding_weight, + shared_embedding_shard_id) + decoder_in_weight_loader(decoder_in_param, shared_embedding_weight, + shared_embedding_shard_id) + lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight, + shared_embedding_shard_id) + else: + encoder_in_weight_loader(encoder_in_param, shared_embedding_weight) + decoder_in_weight_loader(decoder_in_param, shared_embedding_weight) + lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py new file mode 100644 index 0000000..6e955e1 --- /dev/null +++ b/vllm/model_executor/models/bert.py @@ -0,0 +1,513 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Optional + +import torch +from torch import nn +from transformers import BertConfig + +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, PoolerConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler, + PoolingType) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput + +from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only +from .utils import WeightsMapper, maybe_prefix + + +class BertEmbedding(nn.Module): + + def __init__(self, config: BertConfig): + + super().__init__() + self.size = config.hidden_size + self.word_embeddings = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.position_embeddings = VocabParallelEmbedding( + config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = VocabParallelEmbedding( + config.type_vocab_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.position_ids = nn.Parameter( + torch.empty((1, config.max_position_embeddings)), ) + + self.position_embedding_type = config.position_embedding_type + if self.position_embedding_type != "absolute": + raise ValueError("Only 'absolute' position_embedding_type" + + " is supported") + + def forward( + self, + input_ids: torch.Tensor, + seq_lens: torch.Tensor, + position_ids: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + input_shape = input_ids.size() + + # Input embeddings. + inputs_embeds = self.word_embeddings(input_ids) + + # Position embeddings. + position_embeddings = self.position_embeddings(position_ids) + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, + dtype=torch.long, + device=inputs_embeds.device) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) + return embeddings + + +class BertPooler(nn.Module): + + def __init__(self, config: BertConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[0, :] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@support_torch_compile +class BertEncoder(nn.Module): + + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.layer = nn.ModuleList([ + BertLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.layer.{layer_idx}") + for layer_idx in range(config.num_hidden_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + for layer in self.layer: + hidden_states = layer(hidden_states) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, + config: BertConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + + self.attention = BertAttention( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + layer_norm_eps=config.layer_norm_eps, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attention") + + self.intermediate = BertIntermediate( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.intermediate") + + self.output = BertOutput(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + layer_norm_eps=config.layer_norm_eps, + quant_config=quant_config, + prefix=f"{prefix}.output") + + def forward(self, hidden_states: torch.Tensor): + attn_output = self.attention(hidden_states) + intermediate_output = self.intermediate(attn_output) + output = self.output(intermediate_output, attn_output) + return output + + +class BertAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + layer_norm_eps: float, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.self = BertSelfAttention(hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.output") + + self.output = BertSelfOutput(hidden_size=hidden_size, + layer_norm_eps=layer_norm_eps, + quant_config=quant_config, + prefix=f"{prefix}.output") + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + self_output = self.self(hidden_states) + return self.output(self_output, hidden_states) + + +class BertSelfAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = num_attention_heads + assert self.total_num_heads % tp_size == 0 + + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = self.total_num_heads + self.head_dim = self.hidden_size // self.total_num_heads + assert self.head_dim * self.total_num_heads == self.hidden_size + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.attn = Attention(num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=AttentionType.ENCODER_ONLY) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + output = self.attn(q, k, v) + return output + + +class BertSelfOutput(nn.Module): + + def __init__(self, + hidden_size: int, + layer_norm_eps: float, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.dense = RowParallelLinear(input_size=hidden_size, + output_size=hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.dense") + self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertIntermediate(nn.Module): + + def __init__(self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.dense = ColumnParallelLinear(input_size=hidden_size, + output_size=intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.dense") + self.intermediate_act_fn = get_act_fn(hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, + hidden_size: int, + intermediate_size: int, + layer_norm_eps: float, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + + self.dense = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.dense") + + self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertModel(nn.Module, SupportsQuant): + packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]} + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + embedding_class: type = BertEmbedding, + add_pooling_layer: bool = False): + super().__init__() + config = vllm_config.model_config.hf_config + self.embeddings = embedding_class(config) + self.encoder = BertEncoder(vllm_config=vllm_config, + prefix=f"{prefix}.encoder") + self.pooler = BertPooler(config) if add_pooling_layer else None + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + attn_metadata = get_forward_context().attn_metadata + assert hasattr(attn_metadata, "seq_lens_tensor") + hidden_states = self.embeddings( + input_ids=input_ids, + seq_lens=attn_metadata.seq_lens_tensor, + position_ids=position_ids, + token_type_ids=token_type_ids) + return self.encoder(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "query", "q"), + ("qkv_proj", "key", "k"), + ("qkv_proj", "value", "v"), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if self.pooler is None and "pooler" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant): + """A model that uses Bert to provide embedding functionalities. + + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + pooler_config = vllm_config.model_config.pooler_config + self.model = self._build_model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self._pooler = self._build_pooler(pooler_config) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.model(input_ids=input_ids, + position_ids=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + weights = self.hf_to_vllm_mapper.apply(weights) + weights = ((name, data) for name, data in weights + if not name.startswith("lm_head.")) + self.model.load_weights(weights) + + def _build_model(self, + vllm_config: VllmConfig, + prefix: str = "") -> BertModel: + return BertModel(vllm_config=vllm_config, + prefix=prefix, + embedding_class=BertEmbedding) + + def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: + return Pooler.from_config_with_defaults(pooler_config, + pooling_type=PoolingType.CLS, + normalize=True, + softmax=False) + + +class BertForSequenceClassification(nn.Module, SupportsV0Only, + SupportsCrossEncoding, SupportsQuant): + """A model that uses Bert to provide embedding functionalities. + + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + + self.num_labels = config.num_labels + self.bert = BertModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=BertEmbedding, + add_pooling_layer=True) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self._pooler = ClassifierPooler(vllm_config.model_config, + self.classifier, self.bert.pooler) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + + self_weights = [] + + def weight_filter(): + for name, weight in weights: + if name.startswith("bert."): + yield (name[len("bert."):], weight) + else: + self_weights.append((name, weight)) + + self.bert.load_weights(weight_filter()) + + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in self_weights: + if name.startswith("classifier"): + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.bert(input_ids=input_ids, + position_ids=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + token_type_ids=token_type_ids) diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py new file mode 100644 index 0000000..0b7350f --- /dev/null +++ b/vllm/model_executor/models/bert_with_rope.py @@ -0,0 +1,617 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from typing import Optional + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import (get_act_and_mul_fn, + get_act_fn) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models import SupportsV0Only +from vllm.model_executor.models.interfaces import SupportsQuant +from vllm.model_executor.models.utils import WeightsMapper +from vllm.sequence import IntermediateTensors + + +class BertWithRopeEmbedding(nn.Module): + + def __init__(self, config: PretrainedConfig): + + super().__init__() + if config.position_embedding_type not in ["rope", "rotary"]: + raise ValueError("Only 'rotary'('rope') position_embedding_type" + + " is supported") + + self.word_embeddings = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + if config.type_vocab_size > 0: + self.token_type_embeddings = VocabParallelEmbedding( + config.type_vocab_size, config.hidden_size) + else: + self.token_type_embeddings = None + + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + input_shape = input_ids.size() + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + if self.token_type_embeddings is not None: + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, + dtype=torch.long, + device=inputs_embeds.device) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings += token_type_embeddings + + embeddings = self.LayerNorm(embeddings) + return embeddings + + +class BertWithRopeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = True, + rotary_kwargs: Optional[dict] = None, + prefix: str = "", + ): + super().__init__() + + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = num_attention_heads + assert self.total_num_heads % tp_size == 0 + + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = self.total_num_heads + self.head_dim = self.hidden_size // self.total_num_heads + assert self.head_dim * self.total_num_heads == self.hidden_size + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.rotary_emb = get_rope(**rotary_kwargs) + + self.attn = Attention(num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=AttentionType.ENCODER_ONLY) + + self.out_proj = RowParallelLinear(input_size=hidden_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.dense") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.out_proj(attn_output) + return output + + +class BertWithRopeGatedMLP(nn.Module): + + def __init__(self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.act_fn = get_act_and_mul_fn(hidden_act) + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(hidden_states) + hidden_states = self.act_fn(gate_up) + hidden_states, _ = self.down_proj(hidden_states) + return hidden_states + + +class BertWithRopeMLP(nn.Module): + + def __init__(self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.act_fn = get_act_fn(hidden_act) + self.up_proj = ColumnParallelLinear(input_size=hidden_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.up_proj") + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.up_proj(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states, _ = self.down_proj(hidden_states) + return hidden_states + + +class NomicRouter(nn.Module): + + def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int): + super().__init__() + self.moe_top_k = moe_top_k + self.layer = ReplicatedLinear(hidden_size, moe_num_experts, bias=False) + + def forward( + self, x: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.LongTensor]: + weights = self.layer(x.view(-1, x.shape[-1]))[0].softmax( + dim=-1, dtype=torch.float32) + top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1) + weights = weights.to(x.dtype) + top_weights = top_weights.to(x.dtype) + return weights, top_weights, top_experts # type: ignore + + +class NomicExpertMLP(nn.Module): + + def __init__(self, hidden_size: int, ffn_hidden_size: int, + moe_num_experts: int, ffn_act_fn: str): + super().__init__() + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.moe_num_experts = moe_num_experts + + self.w1 = nn.Parameter( + torch.empty(moe_num_experts * ffn_hidden_size, hidden_size)) + self.w2 = nn.Parameter( + torch.empty(moe_num_experts * ffn_hidden_size, hidden_size)) + self.activation_fn = get_act_fn(ffn_act_fn) + + def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor: + expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, + self.hidden_size)[expert_idx] + expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, + self.hidden_size)[expert_idx] + + x1 = x.matmul(expert_w1.t()) + act_out = self.activation_fn(x1) + x2 = act_out.matmul(expert_w2) + return x2 + + +class NomicExperts(nn.Module): + + def __init__(self, config, hidden_size: int, ffn_hidden_size: int, + moe_num_experts: int): + super().__init__() + self.moe_num_experts = moe_num_experts + + self.mlp = NomicExpertMLP(hidden_size=config.n_embd, + ffn_hidden_size=config.n_inner, + moe_num_experts=moe_num_experts, + ffn_act_fn=config.hidden_act) + self.bias = nn.Parameter(torch.zeros(config.n_embd)) + + def forward(self, x: torch.Tensor, weights: torch.Tensor, + top_weights: torch.Tensor, + top_experts: torch.LongTensor) -> torch.Tensor: + q_len, hidden_size = x.shape + x = x.view(-1, hidden_size) + out = torch.zeros_like(x) + + expert_mask = nn.functional.one_hot( + top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0) + for expert_idx in range(0, self.moe_num_experts): + topk_idx, token_idx = torch.where(expert_mask[expert_idx]) + if token_idx.shape[0] == 0: + continue + + token_list = token_idx.tolist() + topk_list = topk_idx.tolist() + + expert_tokens = x[None, token_list].reshape(-1, hidden_size) + expert_out = self.mlp( + expert_tokens, expert_idx) * top_weights[token_list, topk_list, + None] + + out.index_add_(0, token_idx, expert_out) + + out = out.reshape(q_len, hidden_size) + return out + self.bias + + +class NomicMoELayer(nn.Module): + + def __init__(self, config: PretrainedConfig): + super().__init__() + + self.router = NomicRouter( + config.n_embd, + moe_num_experts=config.num_experts, + moe_top_k=config.moe_top_k, + ) + + self.experts = NomicExperts( + config, + hidden_size=config.n_embd, + ffn_hidden_size=config.n_inner, + moe_num_experts=config.num_experts, + ) + + def forward(self, x: torch.Tensor): + weights, top_weights, top_experts = self.router(x) + out = self.experts(x, weights, top_weights, top_experts) + return out + + +class BertWithRopeBlock(nn.Module): + + def __init__(self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + moe: bool = False, + bias: bool = True, + rotary_kwargs: Optional[dict] = None, + prefix: str = ""): + super().__init__() + self.attn = BertWithRopeAttention( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + cache_config=cache_config, + quant_config=quant_config, + bias=bias, + rotary_kwargs=rotary_kwargs, + prefix=f"{prefix}.attention") + + if moe: + self.mlp = NomicMoELayer(config=config, ) + else: + if config.hidden_act in ["silu", "geglu"]: + self.mlp = BertWithRopeGatedMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = BertWithRopeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + self.attn_ln = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.mlp_ln = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor): + attn_output = self.attn(positions, hidden_states) + hidden_states = self.attn_ln(hidden_states + attn_output) + mlp_out = self.mlp(hidden_states) + hidden_states = self.mlp_ln(hidden_states + mlp_out) + return hidden_states + + +@support_torch_compile +class BertWithRopeEncoder(nn.Module): + + def __init__(self, + vllm_config: VllmConfig, + bias: bool = True, + rotary_kwargs: Optional[dict] = None, + prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + every_n = getattr(config, "moe_every_n_layers", 0) + self.layers = nn.ModuleList([ + BertWithRopeBlock(config=config, + cache_config=cache_config, + quant_config=quant_config, + bias=bias, + moe=every_n > 0 and (layer_idx % every_n == 1), + rotary_kwargs=rotary_kwargs, + prefix=f"{prefix}.layer.{layer_idx}") + for layer_idx in range(config.num_hidden_layers) + ]) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + for layer in self.layers: + hidden_states = layer(positions, hidden_states) + return hidden_states + + +class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant): + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.vllm_config = vllm_config + self.config = vllm_config.model_config.hf_config + self.embeddings = BertWithRopeEmbedding(self.config) + self.encoder = BertWithRopeEncoder( + vllm_config=vllm_config, + bias=getattr(self.config, "bias", True), + rotary_kwargs=self.config.rotary_kwargs, + prefix=f"{prefix}.encoder") + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embeddings(input_ids=input_ids, + token_type_ids=token_type_ids) + return self.encoder(positions, hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + weights = self.hf_to_vllm_mapper.apply(weights) + + if self.config.hidden_act in ["silu", "geglu"]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + else: + stacked_params_mapping = [] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "pooler" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class NomicBertModel(BertWithRope): + # for https://huggingface.co/nomic-ai/nomic-bert-2048 + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + "emb_ln": "embeddings.LayerNorm", + "attn.Wqkv": "attn.qkv_proj", + "norm1": "attn_ln", + "mlp.fc1.": "mlp.up_proj.", + "mlp.fc11": "mlp.up_proj", + "mlp.fc12": "mlp.gate_proj", + "mlp.fc2": "mlp.down_proj", + "norm2": "mlp_ln", + }) + + +class GteNewModel(BertWithRope): + # for https://huggingface.co/Alibaba-NLP/new-impl + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + "new.": "", + "layer": "layers", + "attention.qkv_proj": "attn.qkv_proj", + "attention.o_proj": "attn.out_proj", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + # GteNewModel only gate_up_proj does not have bias. + # Hack method learned from vllm/model_executor/models/glm.py + for layer in self.encoder.layers: + layer.mlp.gate_up_proj.bias = None + layer.mlp.gate_up_proj.skip_bias_add = True + + def split_up_gate_proj(self, weights: Iterable[tuple[str, torch.Tensor]]): + n = "mlp.up_gate_proj" + for name, weight in weights: + if n in name: + up, gate = weight.chunk(2, dim=0) + yield name.replace(n, "mlp.up_proj"), up + yield name.replace(n, "mlp.gate_proj"), gate + else: + yield name, weight + + def ignore_unnecessary_layers(self, + weights: Iterable[tuple[str, torch.Tensor]]): + for name, weight in weights: + if name.startswith("classifier"): + continue + yield name, weight + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + weights = self.ignore_unnecessary_layers(weights) + weights = self.split_up_gate_proj(weights) + return super().load_weights(weights) + + +class SnowflakeGteNewModel(GteNewModel): + # for Snowflake/snowflake-arctic-embed-m-v2.0 + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + "layer": "layers", + "attention.qkv_proj": "attn.qkv_proj", + "attention.o_proj": "attn.out_proj", + }) + + +class JinaRobertaModel(BertWithRope): + # for https://huggingface.co/jinaai/jina-embeddings-v3 + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + "emb_ln": "embeddings.LayerNorm", + "mixer.Wqkv": "attn.qkv_proj", + "mixer.out_proj": "attn.out_proj", + "norm1": "attn_ln", + "mlp.fc1.": "mlp.up_proj.", + "mlp.fc2": "mlp.down_proj", + "norm2": "mlp_ln", + }) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return super().forward(input_ids=input_ids, + positions=position_ids, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + token_type_ids=token_type_ids) + + @torch.inference_mode() + def jina_merge_lora_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]): + # use for jina-embeddings-v3 + # Merge Lora weights into a single weight tensor. + # This is a temporary solution until we have a better way to handle + + scaling = self.config.lora_alpha / self.config.lora_rank + device = self.vllm_config.device_config.device + + weights = {name: weight for name, weight in weights} + + o = ".original" + a = ".0.lora_A" + b = ".0.lora_B" + + # text-matching + i = -1 + + for name in list(weights.keys()): + if o in name: + dtype = weights[name].dtype + shape = weights[name].shape + weight_name = name[:-len(o)] + + if "embeddings" in weight_name: + B = weights[weight_name + a][i].to(device).float() + A = weights[weight_name + b][i].to(device).float() + else: + B = weights[weight_name + b][i].to(device).float() + A = weights[weight_name + a][i].to(device).float() + + weight = (weights[weight_name + o].to(device) + + torch.matmul(B, A).view(shape) * scaling) + weight = weight.cpu().to(dtype) + + weights[weight_name.replace(".parametrizations", "")] = weight + + del weights[weight_name + o], weights[weight_name + + a], weights[weight_name + + b] + + return [(name, weight) for name, weight in weights.items()] + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + weights = self.jina_merge_lora_weights(weights) + return super().load_weights(weights) diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py new file mode 100644 index 0000000..2b457fd --- /dev/null +++ b/vllm/model_executor/models/blip.py @@ -0,0 +1,339 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Minimal implementation of BlipVisionModel intended to be only used +within a vision language model.""" +from collections.abc import Iterable +from typing import Optional, Union + +import torch +import torch.nn as nn +from transformers import Blip2VisionConfig, BlipVisionConfig + +from vllm.attention.layer import MultiHeadAttention +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from .interfaces import SupportsQuant + + +def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: + assert image_size % patch_size == 0 + return image_size // patch_size + + +def get_blip_num_patches(*, image_size: int, patch_size: int) -> int: + grid_length = get_blip_patch_grid_length(image_size=image_size, + patch_size=patch_size) + return grid_length * grid_length + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa +class BlipVisionEmbeddings(nn.Module): + + def __init__(self, config: Union[BlipVisionConfig, Blip2VisionConfig]): + super().__init__() + + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=3, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + ) + + self.num_patches = get_blip_num_patches(image_size=self.image_size, + patch_size=self.patch_size) + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter( + torch.randn(1, self.num_positions, self.embed_dim)) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to( + dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + + position_embeds = self.position_embedding.to(target_dtype) + embeddings = embeddings + position_embeds[:, :embeddings.size(1), :] + + return embeddings + + +class BlipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Union[BlipVisionConfig, Blip2VisionConfig], + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + self.projection = RowParallelLinear( + self.embed_dim, + self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.projection", + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + self.attn = MultiHeadAttention(self.num_heads_per_partition, + self.head_dim, self.scale) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + ): + """Input shape: Batch x Time x Channel""" + + qkv_states, _ = self.qkv(hidden_states) + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + out = self.attn(query_states, key_states, value_states) + attn_output, _ = self.projection(out) + + return attn_output, None + + +class BlipMLP(nn.Module): + + def __init__( + self, + config: BlipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1") + self.fc2 = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + + return hidden_states + + +class BlipEncoderLayer(nn.Module): + + def __init__( + self, + config: BlipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + # fallback to sdpa attention if tp unavailable + self.self_attn = BlipAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.layer_norm1 = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.mlp = BlipMLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.layer_norm2 = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class BlipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self + attention layers. Each layer is a [`BlipEncoderLayer`]. + + Args: + config: BlipConfig + """ + + def __init__( + self, + config: BlipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + + self.layers = nn.ModuleList([ + BlipEncoderLayer(config=config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(num_hidden_layers) + ]) + + def forward(self, inputs_embeds: torch.Tensor): + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states) + + return hidden_states + + +class BlipVisionModel(nn.Module, SupportsQuant): + config_class = BlipVisionConfig + main_input_name = "pixel_values" + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + + def __init__( + self, + config: BlipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + require_post_norm: Optional[bool] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + + self.embeddings = BlipVisionEmbeddings(config) + self.encoder = BlipEncoder( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + prefix=f"{prefix}.encoder", + ) + + num_hidden_layers = config.num_hidden_layers + if len(self.encoder.layers) > config.num_hidden_layers: + raise ValueError( + f"The original encoder only has {num_hidden_layers} " + f"layers, but you requested {len(self.encoder.layers)} layers." + ) + + # If possible, skip post_layernorm to conserve memory + if require_post_norm is None: + require_post_norm = len(self.encoder.layers) == num_hidden_layers + + if require_post_norm: + self.post_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + else: + self.post_layernorm = None + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + hidden_states = self.embeddings(pixel_values) + hidden_states = self.encoder(inputs_embeds=hidden_states) + + if self.post_layernorm is None: + return hidden_states + + return self.post_layernorm(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + layer_count = len(self.encoder.layers) + + for name, loaded_weight in weights: + # post_layernorm is not needed in BlipVisionModel + if (name.startswith("post_layernorm") + and self.post_layernorm is None): + continue + + # omit layers when num_hidden_layers_override is set + if name.startswith("encoder.layers"): + layer_idx = int(name.split(".")[2]) + if layer_idx >= layer_count: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py new file mode 100644 index 0000000..27a9208 --- /dev/null +++ b/vllm/model_executor/models/blip2.py @@ -0,0 +1,728 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable, Mapping, Sequence +from typing import Literal, Optional, TypedDict, Union + +import torch +import torch.nn as nn +from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig, + apply_chunking_to_forward) + +from vllm.config import CacheConfig, VllmConfig +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptIndexTargets, + PromptInsertion, PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors + +from .blip import BlipVisionModel +from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, + SupportsQuant) +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) + +# We use this internally as placeholders since there is no image token +# defined on the HuggingFace repo +_IMAGE_TOKEN_ID = 50265 + + +class Blip2ImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size * num_images, num_channels, height, width)`""" + + +class Blip2ImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + """ + + +Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs] + + +class Blip2QFormerMultiHeadAttention(nn.Module): + + def __init__( + self, + config: Blip2QFormerConfig, + *, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + is_cross_attention: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of " + f"the number of attention heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = (config.hidden_size // + config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.scaling = self.attention_head_size**-0.5 + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + kv_hidden_size = config.encoder_hidden_size + else: + kv_hidden_size = config.hidden_size + self.key = nn.Linear(kv_hidden_size, self.all_head_size) + self.value = nn.Linear(kv_hidden_size, self.all_head_size) + + self.position_embedding_type = getattr(config, + "position_embedding_type", + "absolute") + if self.position_embedding_type != "absolute": + raise NotImplementedError("Unsupported position_embedding_type: " + f"{self.position_embedding_type}") + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + x = x.view(*x.size()[:-1], self.num_attention_heads, + self.attention_head_size) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + ): + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + attention_probs = torch.softmax(attention_scores * self.scaling, + dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + context_layer = context_layer.view(*context_layer.size()[:-2], + self.all_head_size) + + return context_layer + + +class Blip2QFormerSelfOutput(nn.Module): + + def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None: + super().__init__() + + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, + hidden_states: torch.Tensor, + input_tensor: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class Blip2QFormerAttention(nn.Module): + + def __init__( + self, + config: Blip2QFormerConfig, + *, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + is_cross_attention: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + + self.attention = Blip2QFormerMultiHeadAttention( + config, + quant_config=quant_config, + cache_config=cache_config, + is_cross_attention=is_cross_attention, + prefix=f"{prefix}.attention", + ) + + self.output = Blip2QFormerSelfOutput(config, prefix=f"{prefix}.output") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + ) -> tuple[torch.Tensor]: + self_output = self.attention( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + attention_output = self.output(self_output, hidden_states) + + return attention_output + + +class Blip2QFormerIntermediate(nn.Module): + + def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None: + super().__init__() + + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.intermediate_act_fn = get_act_fn(config.hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class Blip2QFormerOutput(nn.Module): + + def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None: + super().__init__() + + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, + hidden_states: torch.Tensor, + input_tensor: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class Blip2QFormerLayer(nn.Module): + + def __init__( + self, + config: Blip2QFormerConfig, + *, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + layer_idx: int, + prefix: str = "", + ) -> None: + super().__init__() + + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = Blip2QFormerAttention(config, + quant_config=quant_config, + cache_config=cache_config, + prefix=f"{prefix}.attention") + + self.layer_idx = layer_idx + + if layer_idx % config.cross_attention_frequency == 0: + self.crossattention = Blip2QFormerAttention( + config, + quant_config=quant_config, + cache_config=cache_config, + is_cross_attention=True, + prefix=f"{prefix}.crossattention") + self.has_cross_attention = True + else: + self.has_cross_attention = False + + self.intermediate_query = Blip2QFormerIntermediate( + config, prefix=f"{prefix}.intermediate_query") + self.output_query = Blip2QFormerOutput(config, + prefix=f"{prefix}.output_query") + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + query_length: int, + ): + attention_output = self.attention(hidden_states) + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + query_attention_output = self.crossattention( + query_attention_output, + encoder_hidden_states=encoder_hidden_states, + ) + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], + dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + + return layer_output + + def feed_forward_chunk(self, + attention_output: torch.Tensor) -> torch.Tensor: + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query( + self, attention_output: torch.Tensor) -> torch.Tensor: + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +class Blip2QFormerEncoder(nn.Module): + + def __init__( + self, + config: Blip2QFormerConfig, + *, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + + self.layer = nn.ModuleList([ + Blip2QFormerLayer(config, + quant_config=quant_config, + cache_config=cache_config, + layer_idx=layer_idx, + prefix=f"{prefix}.layer.{layer_idx}") + for layer_idx in range(config.num_hidden_layers) + ]) + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + query_length: int, + ) -> torch.Tensor: + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + + hidden_states = layer_module( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + query_length=query_length, + ) + + return hidden_states + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025 +class Blip2QFormerModel(nn.Module): + + def __init__( + self, + config: Blip2QFormerConfig, + *, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + + self.layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + self.encoder = Blip2QFormerEncoder(config, + quant_config=quant_config, + cache_config=cache_config, + prefix=f"{prefix}.encoder") + + def forward( + self, + query_embeds: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + ) -> torch.Tensor: + query_length = query_embeds.shape[1] + + embedding_output = self.layernorm(query_embeds) + embedding_output = self.dropout(embedding_output) + + sequence_output = self.encoder( + embedding_output, + encoder_hidden_states=encoder_hidden_states, + query_length=query_length, + ) + + return sequence_output + + +class Blip2ProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Blip2Config) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_num_image_tokens(self) -> int: + hf_config = self.get_hf_config() + return hf_config.num_query_tokens + + +class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + hf_config = self.info.get_hf_config() + vision_config = hf_config.vision_config + + max_image_size = vision_config.image_size + num_images = mm_counts.get("image", 0) + + return { + "image": + self._get_dummy_images(width=max_image_size, + height=max_image_size, + num_images=num_images) + } + + +class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + if not mm_data: + # HF processor always adds placeholders even when there's no image + tokenizer = self.info.get_tokenizer() + prompt_ids = tokenizer.encode(prompt) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + return super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + image_token_id = vocab[""] + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [image_token_id] * num_image_tokens + + return [ + PromptInsertion( + modality="image", + target=PromptIndexTargets.start(), + insertion=image_tokens, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor, + info=Blip2ProcessingInfo, + dummy_inputs=Blip2DummyInputsBuilder) +class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, + SupportsQuant): + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return None + + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + + # TODO: Optionally initializes this for supporting embeddings. + self.vision_model = BlipVisionModel(config.vision_config, quant_config) + + self.query_tokens = nn.Parameter( + torch.zeros(1, config.num_query_tokens, + config.qformer_config.hidden_size)) + + self.qformer = Blip2QFormerModel(config.qformer_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.qformer") + + self.language_projection = nn.Linear( + config.qformer_config.hidden_size, + config.text_config.hidden_size, + bias=True, + ) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + actual_dims = tuple(data.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("batch_size", *map(str, expected_dims)) + raise ValueError( + f"The expected shape of pixel values is {expected_expr}. " + f"You supplied {tuple(data.shape)}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Blip2ImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + pixel_values = flatten_bn(pixel_values, concat=True) + + return Blip2ImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values(pixel_values), + ) + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + + image_embeds = flatten_bn(image_embeds, concat=True) + + return Blip2ImageEmbeddingInputs( + type="image_embeds", + data=image_embeds, + ) + + raise AssertionError("This line should be unreachable.") + + def _image_pixels_to_features(self, vision_model: BlipVisionModel, + pixel_values: torch.Tensor) -> torch.Tensor: + + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the vision tower + image_features = vision_model(pixel_values) + + return image_features + + def _process_image_pixels(self, + inputs: Blip2ImagePixelInputs) -> torch.Tensor: + assert self.vision_model is not None + + pixel_values = inputs["data"] + + return self._image_pixels_to_features(self.vision_model, pixel_values) + + def _process_image_input(self, + image_input: Blip2ImageInputs) -> torch.Tensor: + + if image_input["type"] == "image_embeds": + return image_input["data"] + + assert self.vision_model is not None + image_features = self._process_image_pixels(image_input) + + query_tokens = self.query_tokens.expand(image_features.shape[0], -1, + -1) + query_output = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=image_features, + ) + + return self.language_projection(query_output) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None \ + and len(multimodal_embeddings) != 0: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + _IMAGE_TOKEN_ID) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> IntermediateTensors: + """Run forward pass for BLIP-2. + + One key thing to understand is the `input_ids` already accounts for the + positions of the to-be-inserted image embeddings. + + Concretely, consider a text prompt: + `"Question: What's the content of the image? Answer:"`. + + Tokenizer outputs: + `[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`. + + To reserve space in KV cache, we have to insert placeholder tokens + before they are inputted to the model, so the input processor prepends + dummy tokens (denoted as `50265`), resulting in: + `[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`. + + We insert 32 tokens since it corresponds to the number of query + embeddings outputted by the Q-Former and inputted to the language model. + + This way, the `positions` and `attn_metadata` are consistent + with the `input_ids`. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + pixel_values: The pixels in each input image. + + Info: + [Blip2ImageInputs][] + """ + + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py new file mode 100644 index 0000000..9743512 --- /dev/null +++ b/vllm/model_executor/models/bloom.py @@ -0,0 +1,430 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py +# Copyright 2023 The vLLM team. +# Copyright 2022 HuggingFace Inc. team and BigScience workshop. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only BLOOM model compatible with HuggingFace weights.""" +import math +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch import nn +from transformers import BloomConfig +import os +import re + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm import _custom_ops as ops +from vllm.model_executor.utils import pad_weight, gemm_bank_conf + +from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32, + ) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(start=1, + end=1 + 2 * num_remaining_heads, + step=2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + +class BloomAttention(nn.Module): + + def __init__( + self, + config: BloomConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + self.total_num_heads = config.n_head + self.head_dim = self.hidden_size // self.total_num_heads + assert self.head_dim * self.total_num_heads == self.hidden_size + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + bias=True, + quant_config=quant_config, + ) + self.dense = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + ) + + # Create the alibi slopes and slice them. + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(self.total_num_heads) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + + scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + self.quant_method = None + if quant_config is not None: + self.quant_method=quant_config.get_name() + self.quant_config=quant_config + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + del position_ids # Unused. + qkv, _ = self.query_key_value(hidden_states) + # if os.environ.get('FA_PAD') == '1' and self.quant_method is None: + # qkv = qkv[...,:-32] + q, k, v = qkv.chunk(chunks=3, dim=-1) + attn_output = self.attn(q, k, v) + output, _ = self.dense(attn_output) + return output + + +class BloomMLP(nn.Module): + + def __init__( + self, + config: BloomConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.hidden_size + self.dense_h_to_4h = ColumnParallelLinear( + hidden_size, + 4 * hidden_size, + quant_config=quant_config, + ) + self.gelu_impl = get_act_fn("gelu") + self.dense_4h_to_h = RowParallelLinear( + 4 * hidden_size, + hidden_size, + quant_config=quant_config, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.dense_h_to_4h(x) + x = self.gelu_impl(x) + x, _ = self.dense_4h_to_h(x) + return x + + +class BloomBlock(nn.Module): + + def __init__( + self, + config: BloomConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.hidden_size + + self.input_layernorm = nn.LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + self.self_attention = BloomAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attention") + self.post_attention_layernorm = nn.LayerNorm( + hidden_size, eps=config.layer_norm_epsilon) + self.mlp = BloomMLP(config, quant_config) + self.apply_residual_connection_post_layernorm = ( + config.apply_residual_connection_post_layernorm) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + + # Layer norm post the self attention. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + # Self attention. + attention_output = self.self_attention( + position_ids=position_ids, + hidden_states=layernorm_output, + ) + attention_output = attention_output + residual + layernorm_output = self.post_attention_layernorm(attention_output) + + # Get residual + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = attention_output + + # MLP. + output = self.mlp(layernorm_output) + residual + return output + + +@support_torch_compile +class BloomModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config + + self.embed_dim = config.hidden_size + + # Embedding + LN Embedding + self.word_embeddings = VocabParallelEmbedding( + config.vocab_size, + self.embed_dim, + ) + self.word_embeddings_layernorm = nn.LayerNorm( + self.embed_dim, eps=config.layer_norm_epsilon) + + # Transformer blocks + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: BloomBlock( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.h") + + # Final Layer Norm + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + self.quant_method = None + if quant_config is not None: + self.quant_method=quant_config.get_name() + self.quant_config=quant_config + + self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' + self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' + self.use_fa_pad = os.environ.get('FA_PAD') == '1' + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.word_embeddings_layernorm(self.word_embeddings(input_ids)) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(position_ids, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + + if "query_key_value" in name: + # NOTE: BLOOM's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). + # Thus, we need weight conversion. + output_dim = getattr(param, "output_dim", None) + num_heads = self.config.num_attention_heads + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + if self.use_llama_nn and self.quant_method is None: + lay_key_words = [ + "self_attention.query_key_value.weight", + "self_attention.dense.weight", + "mlp.dense_h_to_4h.weight", + "mlp.dense_4h_to_h.weight" + ] + combined_words = "|".join(lay_key_words) + + # lay_qkv_words = ["self_attention.query_key_value.weight"] + # qkv_words = "|".join(lay_qkv_words) + + # lay_qkv_bias_words = ["self_attention.query_key_value.bias"] + # qkv_bias_words = "|".join(lay_qkv_bias_words) + + for layername in loaded_params: + weight = params_dict[layername] + # if self.use_fa_pad and (re.findall(qkv_bias_words, layername)): + # weight.data = pad_weight(weight.data, 32) + + matches = re.findall(combined_words, layername) + if matches: + # if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]): + # weight.data = pad_weight(weight.data, 32) + + # if self.use_fa_pad and (re.findall(qkv_words, layername)): + # if not gemm_bank_conf(weight.data.shape[0]): + # weight.data = pad_weight(weight.data, 32) + + _weight = torch.zeros_like(weight.data) + ori_shape =_weight.shape + + ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1]) + weight.data.copy_(_weight) + + weight.data=weight.data.reshape(ori_shape[1],-1) + return loaded_params + + +class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.transformer = BloomModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.word_embeddings + else: + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size) + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.weight"]) + weights = _add_transformer_prefix(weights) + return loader.load_weights(weights) + + +def _add_transformer_prefix( + weights: Iterable[tuple[str, torch.Tensor]] +) -> Iterable[tuple[str, torch.Tensor]]: + for name, tensor in weights: + if not name.startswith('transformer.'): + name = 'transformer.' + name + yield name, tensor diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py new file mode 100644 index 0000000..74b18df --- /dev/null +++ b/vllm/model_executor/models/chameleon.py @@ -0,0 +1,1146 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from typing import Any, Literal, Optional, TypedDict, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor, + ChameleonVQVAEConfig) + +from vllm.attention import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, row_parallel_weight_loader) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors + +from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, + SupportsQuant) +from .utils import (flatten_bn, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix, merge_multimodal_embeddings) + +logger = init_logger(__name__) + + +class ChameleonImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size * num_images, num_channels, height, width)`""" + + +class ChameleonProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(ChameleonConfig) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(ChameleonProcessor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_num_image_tokens(self) -> int: + processor = self.get_hf_processor() + return processor.image_seq_length + + +class ChameleonDummyInputsBuilder( + BaseDummyInputsBuilder[ChameleonProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + config = self.info.get_hf_config() + + width = height = config.vq_config.resolution + num_images = mm_counts.get("image", 0) + + return { + "image": + self._get_dummy_images(width=width, + height=height, + num_images=num_images) + } + + +class ChameleonMultiModalProcessor( + BaseMultiModalProcessor[ChameleonProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + if not mm_data: + prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + return super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + def _apply_hf_processor_tokens_only( + self, + prompt_tokens: list[int], + ) -> list[int]: + # HF processor adds sep token for chat mode + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + sep_token_id = vocab[tokenizer.sep_token] # type: ignore + + return prompt_tokens + [sep_token_id] + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + image_start_id = vocab[processor.image_start_token] + image_token_id = vocab[processor.image_token] + image_end_id = vocab[processor.image_end_token] + + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=PromptUpdateDetails.select_token_id( + [image_start_id] + image_tokens + [image_end_id], + embed_token_id=image_token_id, + ), + ) + ] + + +class ChameleonLayerNorm(nn.LayerNorm): + + def __init__(self, hidden_size, *args, **kwargs): + super().__init__(hidden_size, *args, **kwargs) + self.normalized_shape = (hidden_size[-1], ) + + set_weight_attrs(self.weight, + {"weight_loader": row_parallel_weight_loader}) + set_weight_attrs(self.bias, + {"weight_loader": row_parallel_weight_loader}) + + def forward(self, hidden_states): + hidden_states = F.layer_norm(hidden_states, + self.normalized_shape, + None, + None, + eps=1e-5) + hidden_states = hidden_states * self.weight + self.bias + return hidden_states + + +# Copied from vllm.model_executor.models.llama.LlamaMLP -> ChameleonMLP +class ChameleonMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config) + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +# Modified from vllm.model_executor.models.llama.LlamaAttention -> ChameleonAttention #noqa +class ChameleonAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 4096, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + ) + self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim)) + self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim)) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def _apply_qk_norm(self, q: torch.Tensor, + k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + # reshape for layernorm + q = q.reshape(-1, self.num_heads, self.head_dim) + k = k.reshape(-1, self.num_kv_heads, self.head_dim) + q = self.q_norm(q) + k = self.k_norm(k) + q = q.view(*q.shape[:-2], -1) + k = k.view(*k.shape[:-2], -1) + return q, k + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self._apply_qk_norm(q, k) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class ChameleonDecoderLayer(nn.Module): + + def __init__( + self, + config: ChameleonConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 4096) + + self.self_attn = ChameleonAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=False, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = ChameleonMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +class ChameleonSwinDecoderLayer(nn.Module): + + def __init__( + self, + config: ChameleonConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 4096) + + self.self_attn = ChameleonAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=False, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = ChameleonMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + + residual = hidden_states + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + hidden_states = self.input_layernorm(hidden_states) + hidden_states = hidden_states + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, residual + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa +class ChameleonVQVAEVectorQuantizer(nn.Module): + + def __init__(self, config: ChameleonVQVAEConfig): + super().__init__() + self.num_embeddings = config.num_embeddings + self.embedding_dim = config.embed_dim + self.beta = getattr(config, "beta", 0.25) + + self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim) + self.re_embed = self.num_embeddings + + def forward(self, hidden_state: torch.Tensor): + hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous() + hidden_state_flattened = hidden_state.view(-1, self.embedding_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + distances = ( + torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) - + 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, + self.embedding.weight.transpose(0, 1))) + + min_encoding_indices = torch.argmin(distances, dim=1) + hidden_state_quant = self.embedding(min_encoding_indices).view( + hidden_state.shape) + + # compute loss for embedding + loss = torch.mean((hidden_state_quant.detach() - hidden_state)** + 2) + self.beta * torch.mean( + (hidden_state_quant - hidden_state.detach())**2) + + # preserve gradients + hidden_state_quant = hidden_state + (hidden_state_quant - + hidden_state).detach() + + # reshape back to match original input shape + hidden_state_quant = hidden_state_quant.permute(0, 3, 1, + 2).contiguous() + + return hidden_state_quant, loss, min_encoding_indices + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa +class ChameleonVQVAEEncoderConvDownsample(nn.Module): + + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, hidden_states: torch.Tensor): + # no asymmetric padding in torch conv, must do it ourselves + hidden_states = F.pad(hidden_states, + pad=(0, 1, 0, 1), + mode="constant", + value=0) + hidden_states = self.conv(hidden_states) + return hidden_states + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa +class ChameleonVQVAEEncoderResnetBlock(nn.Module): + + def __init__( + self, + config: ChameleonVQVAEConfig, + in_channels: int, + out_channels=None, + conv_shortcut=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None \ + else out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = torch.nn.GroupNorm(num_groups=32, + num_channels=in_channels, + eps=1e-6, + affine=True) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + self.norm2 = torch.nn.GroupNorm(num_groups=32, + num_channels=out_channels, + eps=1e-6, + affine=True) + self.dropout = torch.nn.Dropout(config.dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, hidden_states: torch.Tensor): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + residual = self.conv_shortcut(residual) + else: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa +class ChameleonVQVAEEncoderAttnBlock(nn.Module): + + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=32, + num_channels=in_channels, + eps=1e-6, + affine=True) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, hidden_states: torch.Tensor): + residual = hidden_states + hidden_states = self.norm(hidden_states) + query_states = self.q(hidden_states) + key_states = self.k(hidden_states) + value_states = self.v(hidden_states) + + # compute attention + batch_size, channels, height, width = query_states.shape + query_states = query_states.reshape(batch_size, channels, + height * width).permute(0, 2, 1) + key_states = key_states.reshape(batch_size, channels, height * width) + attn_weights = torch.bmm(query_states, key_states) + attn_weights = attn_weights * (int(channels)**(-0.5)) + attn_weights = F.softmax(attn_weights, dim=2) + + # attend to values + value_states = value_states.reshape(batch_size, channels, + height * width) + attn_weights = attn_weights.permute(0, 2, 1) + attn_output = torch.bmm(value_states, + attn_weights).reshape(batch_size, channels, + height, width) + + attn_output = self.proj_out(attn_output) + return residual + attn_output + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa +class ChameleonVQVAEEncoder(nn.Module): + + def __init__(self, config: ChameleonVQVAEConfig): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + base_channels = config.base_channels + resolution = config.resolution + in_channels = config.in_channels + double_latent = config.double_latent + latent_channels = config.latent_channels + channel_multiplier = config.channel_multiplier + + self.conv_in = torch.nn.Conv2d(in_channels, + base_channels, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_channel_multiplier = (1, ) + tuple(channel_multiplier) + self.in_channel_multiplier = in_channel_multiplier + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = base_channels * in_channel_multiplier[i_level] + block_out = base_channels * channel_multiplier[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ChameleonVQVAEEncoderResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_out, + )) + block_in = block_out + if (config.attn_resolutions is not None + and curr_res in config.attn_resolutions + and config.attn_type == "vanilla"): + attn.append(ChameleonVQVAEEncoderAttnBlock(block_in)) + + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + self.mid = nn.Module() + self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_in, + ) + self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock( + block_in) if config.attn_type == "vanilla" else nn.Identity() + self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_in, + ) + + self.norm_out = torch.nn.GroupNorm(num_groups=32, + num_channels=block_in, + eps=1e-6, + affine=True) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * latent_channels if double_latent else latent_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, pixel_values: torch.Tensor): + pixel_values = pixel_values.to(self.conv_in.weight.dtype) + + # downsampling + hidden_states = [self.conv_in(pixel_values)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + hidden_state = self.down[i_level].block[i_block]( + hidden_states[-1]) + if len(self.down[i_level].attn) > 0: + hidden_state = self.down[i_level].attn[i_block]( + hidden_state) + hidden_states.append(hidden_state) + if i_level != self.num_resolutions - 1: + hidden_states.append(self.down[i_level].downsample( + hidden_states[-1])) + + # middle + last_hidden_state = hidden_states[-1] + last_hidden_state = self.mid.block_1(last_hidden_state) + last_hidden_state = self.mid.attn_1(last_hidden_state) + last_hidden_state = self.mid.block_2(last_hidden_state) + + # end + last_hidden_state = self.norm_out(last_hidden_state) + last_hidden_state *= torch.sigmoid(last_hidden_state) + last_hidden_state = self.conv_out(last_hidden_state) + return last_hidden_state + + +# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa +class ChameleonVQVAE(nn.Module): + + def __init__(self, config: ChameleonVQVAEConfig): + super().__init__() + self.encoder = ChameleonVQVAEEncoder(config) + self.quantize = ChameleonVQVAEVectorQuantizer(config) + self.quant_conv = torch.nn.Conv2d(config.latent_channels, + config.embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, + config.latent_channels, 1) + self.eval() # Chameleon's VQ model is frozen + + def encode( + self, pixel_values: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states = self.encoder(pixel_values) + hidden_states = self.quant_conv(hidden_states) + quant, emb_loss, indices = self.quantize(hidden_states) + return quant, emb_loss, indices + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonImageVocabularyMapping #noqa +class ChameleonImageVocabularyMapping: + """ + A class for mapping discrete image tokens from VQGAN to BPE tokens. + """ + + def __init__(self, vocab_map: dict[str, int]): + self.vocab_map = vocab_map + self.image_token_id = vocab_map.get("") + + @cached_property + def val2name(self): + return {v: k for k, v in self.vocab_map.items()} + + @cached_property + def image_tokens(self): + return sorted([ + val for name, val in self.vocab_map.items() + if name.startswith("IMGIMG") + ]) + + @cached_property + def bpe2img(self): + img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)} + + def remap(old_name: str) -> str: + return "".join( + img_tkn_chr_mapping.get(c, c) + for c in old_name[len("IMGIMG"):-1]) + + return { + tok: int(remap(self.val2name[tok])) + for tok in self.image_tokens + } + + @cached_property + def img2bpe(self): + return {v: k for k, v in self.bpe2img.items()} + + @cached_property + def bpe2img_search_tensors(self): + return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor( + sorted(self.bpe2img.values())) + + @cached_property + def img2bpe_mapping_tensor(self): + mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int) + for k, v in self.img2bpe.items(): + mapping[k] = v + return mapping + + def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor: + device = img_batch.device + img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")] + return img_tokens.to(device) + + +class ChameleonModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + ) + self.vocabulary_mapping = ChameleonImageVocabularyMapping( + config.vocabulary_map) + decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \ + else ChameleonSwinDecoderLayer + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: decoder_layer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.vqmodel = ChameleonVQVAE(config.vq_config) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor: + """ + Tokenizes images into discrete tokens with VQGAN module. Converts + obtained image tokens into BPE tokens and wraps with "boi" and "eoi" + special tokens. + """ + batch_size = pixel_values.shape[0] + _, _, image_toks = self.vqmodel.encode(pixel_values) + bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks) + bpe_toks = bpe_toks.view(batch_size, -1) + return bpe_toks + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor( + ChameleonMultiModalProcessor, + info=ChameleonProcessingInfo, + dummy_inputs=ChameleonDummyInputsBuilder) +class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP, SupportsQuant): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] + } + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "" + + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + self.model = ChameleonModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + vq_config: ChameleonVQVAEConfig = self.config.vq_config + expected_dims = (3, vq_config.resolution, vq_config.resolution) + actual_dims = tuple(data.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("batch_size", *map(str, expected_dims)) + raise ValueError( + f"The expected shape of pixel values is {expected_expr}. " + f"You supplied {tuple(data.shape)}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + + if pixel_values is None: + return None + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + pixel_values = flatten_bn(pixel_values, concat=True) + + return ChameleonImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values(pixel_values), + ) + + def get_language_model(self) -> torch.nn.Module: + return self.model + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + assert self.model.vqmodel is not None + image_tokens = self.model.get_image_tokens(image_input["data"].to( + self.config.torch_dtype)) + vision_embeddings = self.model.get_input_embeddings(image_tokens) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + + inputs_embeds = self.model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None \ + and len(multimodal_embeddings) != 0: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.model.vocabulary_mapping.image_token_id) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + # Disallow image tokens which does not include special + # begin-image and end-image tokens + if logits is not None: + image_tokens = self.model.vocabulary_mapping.image_tokens + logits[:, image_tokens] = torch.finfo(logits.dtype).min + + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + + use_default_weight_loading = False + if "vqmodel" in name: + if self.model.vqmodel is not None: + # We only do sharding for language model and + # not vqvae for now. + use_default_weight_loading = True + else: + for (param_name, weight_name, + shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501 + name, + remapped_kv_scale_name, + ) + continue + else: + name = remapped_kv_scale_name + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + if use_default_weight_loading and name in params_dict: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py new file mode 100644 index 0000000..dcd4ecf --- /dev/null +++ b/vllm/model_executor/models/chatglm.py @@ -0,0 +1,543 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from +# https://github.com/THUDM/ChatGLM2-6B +"""Inference-only ChatGLM model compatible with THUDM weights.""" +import json +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch import nn +from torch.nn import LayerNorm +import os +import re + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import ChatGLMConfig + +from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant +from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +from vllm import _custom_ops as ops +from vllm.model_executor.utils import pad_weight, gemm_bank_conf + + +class GLMAttention(nn.Module): + + def __init__( + self, + config: ChatGLMConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.multi_query_attention = config.multi_query_attention + self.total_num_kv_heads = (config.multi_query_group_num + if config.multi_query_attention else + config.num_attention_heads) + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.add_bias_linear or config.add_qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.query_key_value", + ) + self.dense = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=config.add_bias_linear, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + + # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 + rope_ratio = getattr(config, "rope_ratio", 1.0) + max_positions = getattr(config, "seq_length", 8192) + # NOTE: THUDM/cogagent-9b-20241220 uses original_rope=False, + # which is equivalent to is_neox_style=True + is_neox_style = not config.original_rope + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim // 2, + max_position=max_positions, + base=10000 * rope_ratio, + is_neox_style=is_neox_style, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + self.quant_method = None + if quant_config is not None: + self.quant_method=quant_config.get_name() + self.quant_config=quant_config + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.query_key_value(hidden_states) + # if os.environ.get('FA_PAD') == '1' and self.quant_method is None: + # qkv = qkv[...,:-32] + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(position_ids, q, k) + context_layer = self.attn(q, k, v) + attn_output, _ = self.dense(context_layer) + return attn_output + + +class GLMMLP(nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__( + self, + config: ChatGLMConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. + self.dense_h_to_4h = MergedColumnParallelLinear( + config.hidden_size, + [config.ffn_hidden_size] * 2, + bias=config.add_bias_linear, + quant_config=quant_config, + prefix=f"{prefix}.dense_h_to_4h", + ) + + self.activation_func = SiluAndMul() + + # Project back to h. + self.dense_4h_to_h = RowParallelLinear( + config.ffn_hidden_size, + config.hidden_size, + bias=config.add_bias_linear, + quant_config=quant_config, + prefix=f"{prefix}.dense_4h_to_h", + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel, _ = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output, _ = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__( + self, + config: ChatGLMConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.apply_residual_connection_post_layernorm = ( + config.apply_residual_connection_post_layernorm) + + self.fp32_residual_connection = config.fp32_residual_connection + + layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = layer_norm_func(config.hidden_size, + eps=config.layernorm_epsilon) + + # Self attention. + self.self_attention = GLMAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attention") + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = layer_norm_func( + config.hidden_size, eps=config.layernorm_epsilon) + + # MLP + self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp") + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + ) -> torch.Tensor: + # hidden_states: [num_tokens, h] + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output = self.self_attention( + hidden_states=layernorm_output, + position_ids=position_ids, + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = residual + attention_output + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = self.mlp(layernorm_output) + residual + + return output + + +class GLMTransformer(nn.Module): + """Transformer class.""" + + def __init__( + self, + config: ChatGLMConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + self.start_layer, self.end_layer, self.layers = make_layers( + self.num_layers, + lambda prefix: GLMBlock( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers", + ) + + if self.post_layer_norm: + layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. + self.final_layernorm = layer_norm_func( + config.hidden_size, eps=config.layernorm_epsilon) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + ) -> Union[torch.Tensor, IntermediateTensors]: + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(hidden_states=hidden_states, + position_ids=position_ids) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + + # Final layer norm. + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +@support_torch_compile +class ChatGLMModel(nn.Module, SupportsQuant): + packed_modules_mapping = { + "linear_proj.merged_proj": + ["linear_proj.gate_proj", "linear_proj.dense_h_to_4h"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + + self.embedding = VocabParallelEmbedding(config.padded_vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embedding") + + self.num_layers = config.num_layers + self.multi_query_group_num = config.multi_query_group_num + self.kv_channels = config.kv_channels + self.encoder = GLMTransformer(config, + cache_config, + quant_config, + prefix=f"{prefix}.encoder") + + self.output_layer = ParallelLMHead(config.padded_vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.output_layer") + + self.make_empty_intermediate_tensors = ( + self.encoder.make_empty_intermediate_tensors) + + self.quant_method = None + if quant_config is not None: + self.quant_method=quant_config.get_name() + self.quant_config=quant_config + + self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' + self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' + self.use_fa_pad = os.environ.get('FA_PAD') == '1' + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embedding(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + # Run encoder. + hidden_states = self.encoder( + hidden_states=hidden_states, + position_ids=positions, + ) + + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("linear_proj.merged_proj", "linear_proj.gate_proj", 0), + ("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if "rotary_pos_emb.inv_freq" in name: + continue + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + if self.use_llama_nn and self.quant_method is None: + lay_key_words = [ + "self_attention.query_key_value.weight", + "self_attention.dense.weight", + "mlp.dense_h_to_4h.weight", + "mlp.dense_4h_to_h.weight", + ] + combined_words = "|".join(lay_key_words) + + # lay_qkv_words = ["self_attention.query_key_value.weight"] + # qkv_words = "|".join(lay_qkv_words) + + # lay_qkv_bias_words = ["self_attention.query_key_value.bias"] + # qkv_bias_words = "|".join(lay_qkv_bias_words) + + for layername in loaded_params: + weight = params_dict[layername] + if "lm_head.weight" in layername and weight.shape[1] == 4096: + lay_key_words.append("lm_head.weight") + combined_words = "|".join(lay_key_words) + os.environ['LM_NN'] = '1' + else: + os.environ['LM_NN'] = '0' + # if self.use_fa_pad and (re.findall(qkv_bias_words, layername)): + # weight.data = pad_weight(weight.data, 32) + + matches = re.findall(combined_words, layername) + if matches: + # if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]): + # weight.data = pad_weight(weight.data, 32) + + # if self.use_fa_pad and (re.findall(qkv_words, layername)): + # if not gemm_bank_conf(weight.data.shape[0]): + # weight.data = pad_weight(weight.data, 32) + + _weight = torch.zeros_like(weight.data) + ori_shape =_weight.shape + + ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1]) + weight.data.copy_(_weight) + + weight.data=weight.data.reshape(ori_shape[1], -1) + + return loaded_params + + +class ChatGLMBaseModel(nn.Module): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={".word_embeddings": ""}, ) + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + transformer_type: type[ChatGLMModel] = ChatGLMModel, + ) -> None: + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.lora_config = lora_config + self.multimodal_config = multimodal_config + + self.quant_config = quant_config + self.max_position_embeddings = getattr(config, "max_sequence_length", + 8192) + self.transformer = transformer_type(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) + if self.config.tie_word_embeddings: + self.transformer.output_layer.weight = ( + self.transformer.embedding.weight) + self.lm_head = self.transformer.output_layer + self.logits_processor = LogitsProcessor(config.padded_vocab_size) + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + +class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, + SupportsQuant): + packed_modules_mapping = { + "query_key_value": ["query_key_value"], + "dense_h_to_4h": ["dense_h_to_4h"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + if hasattr(config, "vision_config"): + hf_overrides = {"architectures": ["GLM4VForCausalLM"]} + raise RuntimeError( + "The configuration of this model indicates that it supports " + "vision inputs, but you instantiated the text-only version " + "of this model. Please use the vision model by setting " + f"`--hf-overrides '{json.dumps(hf_overrides)}'`") + + super().__init__(vllm_config=vllm_config, prefix=prefix) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py new file mode 100644 index 0000000..dcab008 --- /dev/null +++ b/vllm/model_executor/models/clip.py @@ -0,0 +1,407 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Minimal implementation of CLIPVisionModel intended to be only used +within a vision language model.""" +from collections.abc import Iterable +from typing import Optional, Union + +import torch +import torch.nn as nn +from transformers import CLIPVisionConfig + +from vllm.attention.layer import MultiHeadAttention +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsQuant + +from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs + + +class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]): + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + return self.get_patch_grid_length()**2 + 1 + + def get_image_size(self) -> int: + return self.vision_config.image_size + + def get_patch_size(self) -> int: + return self.vision_config.patch_size + + def get_patch_grid_length(self) -> int: + image_size, patch_size = self.get_image_size(), self.get_patch_size() + assert image_size % patch_size == 0 + return image_size // patch_size + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa +class CLIPVisionEmbeddings(nn.Module): + + def __init__(self, config: CLIPVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + assert self.image_size % self.patch_size == 0 + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, + self.embed_dim) + self.register_buffer("position_ids", + torch.arange(self.num_positions).expand((1, -1)), + persistent=False) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to( + dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + + return embeddings + + +class CLIPAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.out_proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + self.attn = MultiHeadAttention(self.num_heads_per_partition, + self.head_dim, self.scale) + + def forward( + self, + hidden_states: torch.Tensor, + ): + """Input shape: Batch x Time x Channel""" + + qkv_states, _ = self.qkv_proj(hidden_states) + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + out = self.attn(query_states, key_states, value_states) + attn_output, _ = self.out_proj(out) + + return attn_output, None + + +class CLIPMLP(nn.Module): + + def __init__( + self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1") + self.fc2 = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + + return hidden_states + + +class CLIPEncoderLayer(nn.Module): + + def __init__( + self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.self_attn = CLIPAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.layer_norm1 = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.mlp = CLIPMLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.layer_norm2 = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class CLIPEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self + attention layers. Each layer is a [`CLIPEncoderLayer`]. + + Args: + config: CLIPConfig + """ + + def __init__( + self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + self.layers = nn.ModuleList([ + CLIPEncoderLayer(config=config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(num_hidden_layers) + ]) + + def forward( + self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool + ) -> Union[torch.Tensor, list[torch.Tensor]]: + hidden_states_pool = [inputs_embeds] + hidden_states = inputs_embeds + + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states) + if return_all_hidden_states: + hidden_states_pool.append(hidden_states) + # If we have multiple feature sample layers, we return all hidden + # states in order and grab the ones we need by index. + if return_all_hidden_states: + return hidden_states_pool + return hidden_states + + +class CLIPVisionTransformer(nn.Module): + + def __init__( + self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + require_post_norm: Optional[bool] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + embed_dim = config.hidden_size + + self.embeddings = CLIPVisionEmbeddings(config) + + # NOTE: This typo of "layrnorm" is not fixed on purpose to match + # the original transformers code and name of the model weights. + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.encoder = CLIPEncoder( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + prefix=f"{prefix}.encoder", + ) + + num_hidden_layers = config.num_hidden_layers + if len(self.encoder.layers) > config.num_hidden_layers: + raise ValueError( + f"The original encoder only has {num_hidden_layers} " + f"layers, but you requested {len(self.encoder.layers)} layers." + ) + + # If possible, skip post_layernorm to conserve memory + if require_post_norm is None: + require_post_norm = len(self.encoder.layers) == num_hidden_layers + + if require_post_norm: + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + else: + self.post_layernorm = None + + def forward( + self, + pixel_values: torch.Tensor, + feature_sample_layers: Optional[list[int]] = None, + ) -> torch.Tensor: + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + + return_all_hidden_states = feature_sample_layers is not None + + # Produces either the last layer output or all of the hidden states, + # depending on if we have feature_sample_layers or not + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + return_all_hidden_states=return_all_hidden_states) + + # Handle post-norm (if applicable) and stacks feature layers if needed + encoder_outputs = resolve_visual_encoder_outputs( + encoder_outputs, feature_sample_layers, self.post_layernorm, + self.config.num_hidden_layers) + + return encoder_outputs + + +class CLIPVisionModel(nn.Module, SupportsQuant): + config_class = CLIPVisionConfig + main_input_name = "pixel_values" + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + + def __init__( + self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + require_post_norm: Optional[bool] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.vision_model = CLIPVisionTransformer( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + require_post_norm=require_post_norm, + prefix=f"{prefix}.vision_model") + + def forward( + self, + pixel_values: torch.Tensor, + feature_sample_layers: Optional[list[int]] = None, + ) -> torch.Tensor: + return self.vision_model(pixel_values, feature_sample_layers) + + @property + def device(self): + return next(self.parameters()).device + + # (TODO) Add prefix argument for filtering out weights to be loaded + # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + layer_count = len(self.vision_model.encoder.layers) + + for name, loaded_weight in weights: + # post_layernorm is not needed in CLIPVisionModel + if (name.startswith("vision_model.post_layernorm") + and self.vision_model.post_layernorm is None): + continue + + # omit layers when num_hidden_layers_override is set + if name.startswith("vision_model.encoder.layers"): + layer_idx = int(name.split(".")[3]) + if layer_idx >= layer_count: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py new file mode 100644 index 0000000..817c6bb --- /dev/null +++ b/vllm/model_executor/models/commandr.py @@ -0,0 +1,471 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is based on the LLama model definition file in transformers +"""PyTorch Cohere model.""" +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch import nn +from transformers import CohereConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name, + row_parallel_weight_loader) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant +from .utils import (AutoWeightsLoader, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +@torch.compile(backend=current_platform.simple_compile_backend) +def layer_norm_func(hidden_states, weight, variance_epsilon): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + mean = hidden_states.mean(-1, keepdim=True) + variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) + hidden_states = (hidden_states - mean) * torch.rsqrt(variance + + variance_epsilon) + hidden_states = weight.to(torch.float32) * hidden_states + return hidden_states.to(input_dtype) + + +class LayerNorm(nn.Module): + + def __init__(self, param_shape=None, eps=1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(param_shape)) + self.variance_epsilon = eps + set_weight_attrs(self.weight, + {"weight_loader": row_parallel_weight_loader}) + + def forward(self, hidden_states, residuals=None): + hidden_states = layer_norm_func(hidden_states, self.weight, + self.variance_epsilon) + return hidden_states, residuals + + +# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere +class CohereMLP(nn.Module): + + def __init__( + self, + config: CohereConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_up_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class CohereAttention(nn.Module): + + def __init__( + self, + config: CohereConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + tp_size = get_tensor_model_parallel_world_size() + self.config = config + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.total_num_heads = config.num_attention_heads + self.num_heads = self.total_num_heads // tp_size + self.head_dim = self.hidden_size // self.total_num_heads + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.max_position_embeddings = getattr( + config, "model_max_length", None) or getattr( + config, "max_position_embeddings", 8192) + self.rope_theta = config.rope_theta + self.rope_scaling = getattr(config, "rope_scaling", None) + self.use_qk_norm = getattr(config, "use_qk_norm", False) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=self.rope_scaling, + is_neox_style=False, + ) + + # Model v2 has interleaved sliding windows, v1 does not + interleaved_sliding_window = getattr(config, + "interleaved_sliding_window", + None) + self.v1 = interleaved_sliding_window is None + + layer_idx = extract_layer_index(prefix) + layer_has_sliding_window = ( + getattr(config, "sliding_window_pattern", False) + and (layer_idx + 1) % self.config.sliding_window_pattern != 0) + + self.sliding_window = (interleaved_sliding_window + if layer_has_sliding_window else None) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=self.sliding_window, + prefix=f"{prefix}.attn") + if self.use_qk_norm: + self.q_norm = LayerNorm(param_shape=(self.num_heads, + self.head_dim), + eps=config.layer_norm_eps) + self.k_norm = LayerNorm(param_shape=(self.num_kv_heads, + self.head_dim), + eps=config.layer_norm_eps) + + def _apply_qk_norm(self, q, k): + q = q.view(*q.shape[:-1], -1, self.head_dim) + k = k.view(*k.shape[:-1], -1, self.head_dim) + q, _ = self.q_norm(q) + k, _ = self.k_norm(k) + q = q.view(*q.shape[:-2], -1) + k = k.view(*k.shape[:-2], -1) + return q, k + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.use_qk_norm: + q, k = self._apply_qk_norm(q, k) + if self.v1 or self.sliding_window: + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class CohereDecoderLayer(nn.Module): + + def __init__(self, + config: CohereConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = CohereAttention(config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn") + + self.mlp = CohereMLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), + eps=config.layer_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states_attention = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states_mlp = self.mlp(hidden_states) + # Add everything together + hidden_states = residual + hidden_states_attention + hidden_states_mlp + + return hidden_states, residual + + +@support_torch_compile +class CohereModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.quant_config = quant_config + + self.config = config + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: CohereDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.norm = LayerNorm(param_shape=(config.hidden_size), + eps=config.layer_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + for param_name, shard_name, shard_id in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + # LoRA specific attributes + embedding_modules = {"embed_tokens": "input_embeddings"} + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + # currently all existing command R models have `tie_word_embeddings` + # enabled + assert config.tie_word_embeddings + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.quant_config = quant_config + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + scale=config.logit_scale) + self.model = CohereModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + is_not_lora = hasattr(self.model.embed_tokens, 'weight') + if is_not_lora: + logits = self.logits_processor(self.model.embed_tokens, + hidden_states, sampling_metadata) + else: + logits = self.logits_processor(self.model.embed_tokens.base_layer, + hidden_states, sampling_metadata) + + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, skip_prefixes=["lm_head", "rotary_emb.inv_freq"]) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py new file mode 100644 index 0000000..552c4b0 --- /dev/null +++ b/vllm/model_executor/models/config.py @@ -0,0 +1,200 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from copy import deepcopy +from typing import TYPE_CHECKING + +from vllm.logger import init_logger + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +class VerifyAndUpdateConfig: + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + raise NotImplementedError + + +class GteNewModelConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + config = vllm_config.model_config.hf_config + + assert config.__class__.__name__ == "NewConfig" + assert config.hidden_act == "gelu" + + config.hidden_act = "geglu" + + head_dim = config.hidden_size // config.num_attention_heads + config.rotary_kwargs = { + "head_size": head_dim, + "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), + "max_position": config.max_position_embeddings, + "base": config.rope_theta, + "rope_scaling": getattr(config, "rope_scaling", None) + } + + +class JinaRobertaModelConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + config = vllm_config.model_config.hf_config + + if config.position_embedding_type == "rotary": + assert config.__class__.__name__ == "XLMRobertaFlashConfig" + + head_dim = config.hidden_size // config.num_attention_heads + config.rotary_kwargs = { + "head_size": head_dim, + "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), + "max_position": config.max_position_embeddings, + "base": getattr(config, "rope_theta", config.rotary_emb_base), + "rope_scaling": getattr(config, "rope_scaling", None) + } + + +class NomicBertModelConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + config = vllm_config.model_config.hf_config + + assert config.__class__.__name__ == "NomicBertConfig" + assert config.activation_function in ["swiglu", "gelu"] + config.position_embedding_type = getattr(config, + "position_embedding_type", + "rope") + + if config.activation_function == "swiglu": + config.hidden_act = "silu" + else: + config.hidden_act = config.activation_function + + assert (config.mlp_fc1_bias == config.mlp_fc2_bias == + config.qkv_proj_bias) + config.bias = config.qkv_proj_bias + + assert config.rotary_emb_scale_base is None + assert not config.rotary_emb_interleaved + + config.layer_norm_eps = config.layer_norm_epsilon + config.intermediate_size = config.n_inner + config.hidden_size = config.n_embd + config.num_hidden_layers = config.n_layer + + head_dim = config.hidden_size // config.num_attention_heads + rotary_emb_dim = head_dim * config.rotary_emb_fraction + max_trained_positions = getattr(config, "max_trained_positions", 2048) + config.rotary_kwargs = { + "head_size": head_dim, + "rotary_dim": rotary_emb_dim, + "max_position": max_trained_positions, + "base": getattr(config, "rope_theta", config.rotary_emb_base), + "rope_scaling": getattr(config, "rope_scaling", None) + } + + # we ignore config.rotary_scaling_factor so that for datasets shorter + # than max_trained_positions 2048, the results are consistent + # with SentenceTransformer. + # The context extension uses vllm style rope_theta and rope_scaling. + # See #17785 #18755 + if (not vllm_config.model_config.hf_overrides + and vllm_config.model_config.original_max_model_len is None): + # Default + # Reset max_model_len to max_trained_positions. + # nomic-embed-text-v2-moe the length is set to 512 + # by sentence_bert_config.json. + max_model_len_before = vllm_config.model_config.max_model_len + max_model_len = min(vllm_config.model_config.max_model_len, + max_trained_positions) + + vllm_config.recalculate_max_model_len(max_model_len) + logger.warning( + "Nomic context extension is disabled. " + "Changing max_model_len from %s to %s. " + "To enable context extension, see: " + "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html", + max_model_len_before, vllm_config.model_config.max_model_len) + else: + # We need to re-verify max_model_len to avoid lengths + # greater than position_embedding. + model_config = vllm_config.model_config + hf_text_config = model_config.hf_text_config + + if isinstance(model_config.hf_overrides, dict): + # hf_overrides_kw + max_model_len = model_config.hf_overrides.get( + "max_model_len", vllm_config.model_config.max_model_len) + else: + # hf_overrides_fn + # This might be overridden by sentence_bert_config.json. + max_model_len = vllm_config.model_config.max_model_len + + # reset hf_text_config for recalculate_max_model_len. + if hasattr(hf_text_config, "max_model_len"): + delattr(hf_text_config, "max_model_len") + hf_text_config.max_position_embeddings = max_trained_positions + hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"] + + # The priority of sentence_bert_config.json is higher + # than max_position_embeddings + encoder_config = deepcopy(model_config.encoder_config) + encoder_config.pop("max_seq_length", None) + model_config.encoder_config = encoder_config + + vllm_config.recalculate_max_model_len(max_model_len) + + +class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + config = vllm_config.model_config.hf_config + + is_original_qwen3_reranker = getattr(config, + "is_original_qwen3_reranker", + False) + + if not is_original_qwen3_reranker: + return + + tokens = getattr(config, "classifier_from_token", None) + assert tokens is not None and len(tokens) == 2, \ + ("Try loading the original Qwen3 Reranker?, see: " + "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py") + vllm_config.model_config.hf_config.method = "from_2_way_softmax" + + +class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + config = vllm_config.model_config.hf_config + + assert config.__class__.__name__ == "GteConfig" + assert config.hidden_act == "gelu" + + config.hidden_act = "geglu" + + head_dim = config.hidden_size // config.num_attention_heads + config.rotary_kwargs = { + "head_size": head_dim, + "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), + "max_position": config.max_position_embeddings, + "base": config.rope_theta, + "rope_scaling": getattr(config, "rope_scaling", None) + } + + +MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { + "GteModel": SnowflakeGteNewModelConfig, + "GteNewModel": GteNewModelConfig, + "NomicBertModel": NomicBertModelConfig, + "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig, + "XLMRobertaModel": JinaRobertaModelConfig, +} diff --git a/vllm/model_executor/models/constant_size_cache.py b/vllm/model_executor/models/constant_size_cache.py new file mode 100644 index 0000000..f03c58a --- /dev/null +++ b/vllm/model_executor/models/constant_size_cache.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from typing import Any + +import torch + +from vllm.attention.backends.utils import PAD_SLOT_ID + + +class ConstantSizeCache(ABC): + """ + Abstract base class for managing constant size caches + like Mamba and Minimax. + """ + + def __init__(self, max_batch_size: int): + # Maps between the request id and a dict that maps between the seq_id + # and its index inside the cache + self.cache_indices_mapping: dict[str, dict[int, int]] = {} + self.free_cache_indices = list(range(max_batch_size)) + + @property + @abstractmethod + def cache(self) -> Any: + """Return the underlying cache tensor(s)""" + pass + + @abstractmethod + def _copy_cache(self, from_index: int, to_index: int): + """Copy cache data from one index to another""" + pass + + def current_run_tensors(self, **kwargs) -> tuple: + """ + Return the tensors for the current run's conv and ssm state. + """ + if "seqlen_agnostic_capture_inputs" not in kwargs: + # We get here only on Prefill/Eager mode runs + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + finished_requests_ids = kwargs["finished_requests_ids"] + + self._release_finished_requests(finished_requests_ids) + state_indices = self._prepare_current_run_cache( + request_ids_to_seq_ids, finished_requests_ids) + + state_indices_tensor = torch.as_tensor(state_indices, + dtype=torch.int32, + device="cuda") + cache_tensors = self.cache + else: + # CUDA graph capturing runs + cache_tensors, state_indices_tensor = kwargs[ + "seqlen_agnostic_capture_inputs"] + + return (cache_tensors, state_indices_tensor) + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant state_indices into the CUDA graph input buffer + """ + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + finished_requests_ids = kwargs["finished_requests_ids"] + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + assert "seqlen_agnostic_capture_inputs" in input_buffers + _, input_state_indices_buffer = input_buffers[ + "seqlen_agnostic_capture_inputs"] + + self._release_finished_requests(finished_requests_ids) + state_indices = self._prepare_current_run_cache( + request_ids_to_seq_ids, finished_requests_ids) + cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len( + state_indices) + state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len) + + input_state_indices_buffer.copy_( + torch.as_tensor(state_indices, dtype=torch.int32, device="cuda")) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + """ + Provide the CUDA graph capture runs with a buffer in adjusted size. + The buffer is used to maintain the Cache during the CUDA graph replay + runs. + """ + state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size, + dtype=torch.int32, + device="cuda") + return (self.cache, state_indices_tensor) + + def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int, + finished_requests_ids) -> int: + """ + Assign (req_id,seq_id) pair to a `destination_index` index, if + already occupied, move the occupying index to a free index. + """ + if cur_rid in finished_requests_ids: + # set as pad, do not allocate destination index + return PAD_SLOT_ID + elif cur_rid not in self.cache_indices_mapping: + destination_index = self.free_cache_indices.pop() + self.cache_indices_mapping[cur_rid] = {seq_id: destination_index} + return destination_index + elif seq_id not in (seq_ids2indices := + self.cache_indices_mapping[cur_rid]): + # parallel sampling , where n > 1, assume prefill have + # already happened, so we copy the + # existing cache into the siblings seq_ids caches + index_exists = next(iter(seq_ids2indices.values())) + # case of decoding n>1, copy prefill cache to decoding indices + destination_index = self.free_cache_indices.pop() + self._copy_cache(from_index=index_exists, + to_index=destination_index) + self.cache_indices_mapping[cur_rid][seq_id] = destination_index + return destination_index + else: + return self.cache_indices_mapping[cur_rid][seq_id] + + def _prepare_current_run_cache( + self, request_ids_to_seq_ids: dict[str, list[int]], + finished_requests_ids: list[str]) -> list[int]: + return [ + self._assign_seq_id_to_cache_index(req_id, seq_id, + finished_requests_ids) + for req_id, seq_ids in request_ids_to_seq_ids.items() + for seq_id in seq_ids + ] + + def _release_finished_requests(self, + finished_seq_groups_req_ids: list[str]): + for req_id in finished_seq_groups_req_ids: + if req_id in self.cache_indices_mapping: + for seq_id in self.cache_indices_mapping[req_id]: + self.free_cache_indices.append( + self.cache_indices_mapping[req_id][seq_id]) + self.cache_indices_mapping.pop(req_id) diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py new file mode 100644 index 0000000..7a4dd69 --- /dev/null +++ b/vllm/model_executor/models/dbrx.py @@ -0,0 +1,472 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Optional, Union + +import torch +import torch.nn as nn + +from vllm.attention import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.dbrx import DbrxConfig + +from .interfaces import SupportsPP +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class DbrxRouter(nn.Module): + """A Router implementation for DBRX that returns logits for each expert + per token. + """ + + def __init__( + self, + config: DbrxConfig, + params_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.num_total_experts = config.ffn_config.moe_num_experts + self.d_model = config.d_model + self.layer = ReplicatedLinear( + self.d_model, + self.num_total_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + router_logits, _ = self.layer(hidden_states) + return router_logits + + +class DbrxExperts(FusedMoE): + + def __init__( + self, + config: DbrxConfig, + quant_config: Optional[QuantizationConfig] = None, + params_dtype: Optional[torch.dtype] = None, + prefix: str = "", + ): + super().__init__( + num_experts=config.ffn_config.moe_num_experts, + top_k=config.ffn_config.moe_top_k, + hidden_size=config.d_model, + intermediate_size=config.ffn_config.ffn_hidden_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=get_tensor_model_parallel_world_size(), + prefix=prefix, + ) + self.config = config + self.d_model = config.d_model + self.intermediate_size = (self.config.ffn_config.ffn_hidden_size // + self.tp_size) + + # Define custom weight loader for dbrx model + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, param_name: str): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + # DBRX uses GLU for each experts. + # GLU has 3 linear layers: w1, v1 and w2. + if weight_name.endswith("w1"): + if param_name.endswith("weight"): + loaded_weight = torch.reshape( + loaded_weight, + [-1, self.intermediate_size * self.tp_size, self.d_model], + ) + param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :] + elif param_name.endswith("weight_scale"): + param_data[:, 0] = loaded_weight + else: + param_data = loaded_weight + if weight_name.endswith("v1"): + if param_name.endswith("weight"): + loaded_weight = torch.reshape( + loaded_weight, + [-1, self.intermediate_size * self.tp_size, self.d_model], + ) + param_data[:, shard_size:2 * + shard_size, :] = loaded_weight[:, shard, :] + elif param_name.endswith("weight_scale"): + param_data[:, 1] = loaded_weight + else: + param_data[:] = loaded_weight + if weight_name.endswith("w2"): + if param_name.endswith("weight"): + loaded_weight = torch.reshape( + loaded_weight, + [-1, self.intermediate_size * self.tp_size, self.d_model], + ).transpose(1, 2) + param_data[:] = loaded_weight[:, :, shard] + else: + param_data[:] = loaded_weight + + +class DbrxMoE(nn.Module): + """A tensor-parallel MoE implementation for DBRX. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__( + self, + config: DbrxConfig, + quant_config: Optional[QuantizationConfig] = None, + params_dtype: Optional[torch.dtype] = None, + prefix: str = "", + ): + super().__init__() + self.d_model = config.d_model + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + self.router = DbrxRouter(config, self.params_dtype) + + self.experts = DbrxExperts(config=config, + quant_config=quant_config, + params_dtype=self.params_dtype, + prefix=f"{prefix}.experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.d_model) + # router_logits: (num_tokens, n_experts) + router_logits = self.router(hidden_states) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class DbrxAttention(nn.Module): + + def __init__( + self, + config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.d_model = config.d_model + self.total_num_heads = config.n_heads + self.head_dim = self.d_model // self.total_num_heads + self.total_num_kv_heads = config.attn_config.kv_n_heads + self.clip_qkv = config.attn_config.clip_qkv + self.rope_theta = config.attn_config.rope_theta + self.max_position = config.max_seq_len + + # pylint: disable=invalid-name + self.Wqkv = QKVParallelLinear( + self.d_model, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.out_proj = RowParallelLinear( + self.d_model, + self.d_model, + bias=False, + quant_config=quant_config, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + self.tp_size = tp_world_size + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.Wqkv(hidden_states) + if self.clip_qkv is not None: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(position_ids, q, k) + attn_output = self.attn(q, k, v) + hidden_states, _ = self.out_proj(attn_output) + return hidden_states + + +class DbrxFusedNormAttention(nn.Module): + + def __init__( + self, + config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.d_model = config.d_model + self.attn = DbrxAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") + self.norm_1 = nn.LayerNorm(self.d_model) + self.norm_2 = nn.LayerNorm(self.d_model) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.norm_1(hidden_states) + x = self.attn( + position_ids=position_ids, + hidden_states=hidden_states, + ) + hidden_states = residual + x + residual = hidden_states + hidden_states = self.norm_2(hidden_states) + return hidden_states, residual + + +class DbrxBlock(nn.Module): + + def __init__( + self, + config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.norm_attn_norm = DbrxFusedNormAttention( + config, + cache_config, + quant_config, + prefix=f"{prefix}.norm_attn_norm") + self.ffn = DbrxMoE(config, quant_config, prefix=f"{prefix}.ffn") + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + hidden_states, residual = self.norm_attn_norm( + position_ids=position_ids, + hidden_states=hidden_states, + ) + hidden_states = self.ffn(hidden_states) + hidden_states = hidden_states + residual + return hidden_states + + +class DbrxModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.quant_config = quant_config + self.wte = VocabParallelEmbedding( + config.vocab_size, + config.d_model, + ) + self.start_layer, self.end_layer, self.blocks = make_layers( + config.n_layers, + lambda prefix: DbrxBlock( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.blocks", + ) + self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) + for module in self.modules(): + if hasattr(module, "bias") and isinstance(module.bias, + nn.Parameter): + # Remove the bias term in Linear and LayerNorm. + module.register_parameter("bias", None) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.d_model)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + assert intermediate_tensors + hidden_states = intermediate_tensors["hidden_states"] + for block in self.blocks[self.start_layer:self.end_layer]: + hidden_states = block(position_ids, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.norm_f(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + expert_params_mapping = [( + "w13" if weight_name in ["w1", "v1"] else "w2", + f"mlp.{weight_name}", + ) for weight_name in ["w1", "v1", "w2"]] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + if name.endswith(("w1", "w2", "v1")): + name = name + "_weight" + for param_name, weight_name in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, weight_name, name) + break + + else: + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class DbrxForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + if config.tie_word_embeddings: + raise ValueError( + "tie_word_embeddings is not supported for Dbrx models.") + self.quant_config = quant_config + self.unpadded_vocab_size = config.vocab_size + self.transformer = DbrxModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.d_model, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py new file mode 100644 index 0000000..ca19621 --- /dev/null +++ b/vllm/model_executor/models/deepseek.py @@ -0,0 +1,496 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Deepseek model.""" +from collections.abc import Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (AutoWeightsLoader, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) +import vllm.envs as envs + + +class DeepseekMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DeepseekMoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.n_routed_experts = config.n_routed_experts + self.top_k = config.num_experts_per_tok + if self.tp_size > self.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {self.n_routed_experts}.") + + self.experts = nn.ModuleList([ + DeepseekMLP(hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False) + for idx in range(self.n_routed_experts) + ]) + self.pack_params() + + self.gate = ReplicatedLinear(config.hidden_size, + self.n_routed_experts, + bias=False, + quant_config=None) + + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = DeepseekMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + ) + + def pack_params(self): + w1 = [] + w2 = [] + for expert in self.experts: + w1.append(expert.gate_up_proj.weight) + w2.append(expert.down_proj.weight) + self.w1 = torch._utils._flatten_dense_tensors(w1) + w1s = torch._utils._unflatten_dense_tensors(self.w1, w1) + for data, param in zip(w1s, w1): + param.data = data + self.w1 = self.w1.view(len(w1), *w1s[0].shape) + + self.w2 = torch._utils._flatten_dense_tensors(w2) + w2s = torch._utils._unflatten_dense_tensors(self.w2, w2) + for data, param in zip(w2s, w2): + param.data = data + + self.w2 = self.w2.view(len(w2), *w2s[0].shape) + + if envs.VLLM_USE_NN: + self.w1 = self.w1.permute(0,2,1).contiguous() + for expert, w in zip(self.experts, self.w1): + expert.gate_up_proj.weight.data = w.permute(1,0) + + self.w2 = self.w2.permute(0, 2, 1).contiguous() + for expert, w in zip(self.experts, self.w2): + expert.down_proj.weight.data = w.permute(1, 0) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + if self.config.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = fused_moe(hidden_states, + self.w1, + self.w2, + router_logits, + self.top_k, + renormalize=self.config.norm_topk_prob, + inplace=True) + + if self.config.n_shared_experts is not None: + final_hidden_states = final_hidden_states + shared_output + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_dim) + + +class DeepseekAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class DeepseekDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = DeepseekAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = DeepseekMoE(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = DeepseekMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class DeepseekModel(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: DeepseekDecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers") + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if (("mlp.experts." in name or "mlp.shared_experts." in name) + and name not in params_dict): + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if (("mlp.experts." in name or "mlp.shared_experts." in name) + and name not in params_dict): + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class DeepseekForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = DeepseekModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py new file mode 100644 index 0000000..26e1b88 --- /dev/null +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -0,0 +1,662 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +import re + +from collections.abc import Iterable +from typing import Iterable, Optional + + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.compilation.decorators import support_torch_compile +from .deepseek_v2 import (DeepseekV2DecoderLayer, + get_spec_layer_idx_from_weight_name) +from .interfaces import SupportsPP +from .utils import maybe_prefix +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config + + +class SharedHead(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(hidden_states) + + +class DeepSeekMultiTokenPredictorLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.eh_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + self.shared_head = SharedHead(config=config, quant_config=quant_config) + self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config, + cache_config, quant_config) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_index: int = 0, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds[positions == 0] = 0 + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states = self.eh_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states, residual = self.mtp_block(positions=positions, + hidden_states=hidden_states, + residual=None) + hidden_states = residual + hidden_states + return hidden_states + + +class DeepSeekMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict({ + str(idx): + DeepSeekMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + }) + + self.logits_processor = LogitsProcessor(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + current_step_idx = (spec_step_idx % self.num_mtp_layers) + return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( + input_ids, + positions, + previous_hidden_states, + inputs_embeds, + current_step_idx, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> torch.Tensor: + current_step_idx = (spec_step_idx % self.num_mtp_layers) + mtp_layer = self.layers[str(self.mtp_start_layer_idx + + current_step_idx)] + logits = self.logits_processor(mtp_layer.shared_head.head, + mtp_layer.shared_head(hidden_states), + sampling_metadata) + return logits + +@support_torch_compile +class DeepSeekMTP(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.quant_method = None + if quant_config is not None: + self.quant_method = quant_config.get_name() + os.environ['LLAMA_NN'] = '0' + os.environ['LM_NN'] = '0' + # The AWQ layer of MTP uses BlockInt8W8A8. + if self.quant_method == "moe_wna16" or self.quant_method == "awq_marlin": + vllm_config.quant_config = BlockInt8Config(is_checkpoint_int8_serialized=True, weight_block_size=[128,128]) + + self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' + + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, + previous_hidden_states, inputs_embeds, + spec_step_idx) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> Optional[torch.Tensor]: + return self.model.compute_logits(hidden_states, sampling_metadata, + spec_step_idx) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is None: + continue + name = self._rewrite_spec_layer_name(spec_layer, name) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # According to DeepSeek-V3 Technical Report, MTP modules + # shares embedding layer. We only load the first weights. + if (spec_layer != self.model.mtp_start_layer_idx + and ".layers" not in name): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + if self.use_llama_nn and self.quant_method is None: + lay_key_words = [ + "self_attn.eh_proj.weight", + "self_attn.q_proj.weight", + "self_attn.q_a_proj.weight", + "self_attn.q_b_proj.weight", + "self_attn.kv_a_proj_with_mqa.weight", + "self_attn.kv_b_proj.weight", + "self_attn.o_proj.weight", + "mlp.gate_up_proj.weight", + "mlp.down_proj.weight", + "mlp.gate.weight", + "shared_experts.gate_up_proj.weight", + "shared_experts.down_proj.weight", + "shared_head.head.weight", + ] + + combined_words = "|".join(lay_key_words) + + for layername in loaded_params: + weight = params_dict[layername] + matches = re.findall(combined_words, layername) + if matches: + _weight = torch.zeros_like(weight.data) + ori_shape =_weight.shape + + ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1]) + weight.data.copy_(_weight) + + weight.data=weight.data.reshape(ori_shape[1],-1) + + return loaded_params + + def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: + """ + Rewrite the weight name to match the format of the original model. + Add .mtp_block for modules in transformer layer block for spec layer + """ + spec_layer_weight_names = [ + "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" + ] + spec_layer_weight = False + for weight_name in spec_layer_weight_names: + if weight_name in name: + spec_layer_weight = True + break + if not spec_layer_weight: + # treat rest weights as weights for transformer layer block + name = name.replace(f"model.layers.{spec_layer}.", + f"model.layers.{spec_layer}.mtp_block.") + return name + + + +# # SPDX-License-Identifier: Apache-2.0 +# # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# import os +# import re + +# from collections.abc import Iterable +# from typing import Iterable, Optional + + +# import torch +# import torch.nn as nn +# from transformers import PretrainedConfig + +# from vllm.config import CacheConfig, ModelConfig, VllmConfig +# from vllm.model_executor.layers.fused_moe import FusedMoE +# from vllm.model_executor.layers.layernorm import RMSNorm +# from vllm.model_executor.layers.logits_processor import LogitsProcessor +# from vllm.model_executor.layers.quantization import QuantizationConfig +# from vllm.model_executor.layers.vocab_parallel_embedding import ( +# ParallelLMHead, VocabParallelEmbedding) +# from vllm.model_executor.model_loader.weight_utils import default_weight_loader +# from vllm.model_executor.sampling_metadata import SamplingMetadata +# from vllm.sequence import IntermediateTensors +# from vllm.compilation.decorators import support_torch_compile +# from .deepseek_v2 import (DeepseekV2DecoderLayer, +# get_spec_layer_idx_from_weight_name) +# from .interfaces import SupportsPP +# from .utils import maybe_prefix +# from vllm import _custom_ops as ops +# from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config + + +# class SharedHead(nn.Module): + +# def __init__( +# self, +# config: PretrainedConfig, +# quant_config: Optional[QuantizationConfig] = None, +# ) -> None: +# super().__init__() +# self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) +# self.head = ParallelLMHead(config.vocab_size, +# config.hidden_size, +# quant_config=quant_config) + +# def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: +# return self.norm(hidden_states) + + +# class DeepSeekMultiTokenPredictorLayer(nn.Module): + +# def __init__( +# self, +# config: PretrainedConfig, +# prefix: str, +# model_config: ModelConfig, +# cache_config: Optional[CacheConfig] = None, +# quant_config: Optional[QuantizationConfig] = None, +# ) -> None: +# super().__init__() +# self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) +# self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) +# self.eh_proj = nn.Linear(config.hidden_size * 2, +# config.hidden_size, +# bias=False) +# self.shared_head = SharedHead(config=config, quant_config=quant_config) +# self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config, +# cache_config, quant_config) + +# def forward( +# self, +# input_ids: torch.Tensor, +# positions: torch.Tensor, +# previous_hidden_states: torch.Tensor, +# inputs_embeds: Optional[torch.Tensor] = None, +# spec_step_index: int = 0, +# ) -> torch.Tensor: +# assert inputs_embeds is not None +# # masking inputs at position 0, as not needed by MTP +# inputs_embeds[positions == 0] = 0 +# inputs_embeds = self.enorm(inputs_embeds) +# previous_hidden_states = self.hnorm(previous_hidden_states) + +# hidden_states = self.eh_proj( +# torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + +# hidden_states, residual = self.mtp_block(positions=positions, +# hidden_states=hidden_states, +# residual=None) +# hidden_states = residual + hidden_states +# return hidden_states + + +# class DeepSeekMultiTokenPredictor(nn.Module): + +# def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): +# super().__init__() +# config = vllm_config.model_config.hf_config +# self.mtp_start_layer_idx = config.num_hidden_layers +# self.num_mtp_layers = config.num_nextn_predict_layers +# # to map the exact layer index from weights +# self.layers = torch.nn.ModuleDict({ +# str(idx): +# DeepSeekMultiTokenPredictorLayer( +# config, +# f"{prefix}.layers.{idx}", +# model_config=vllm_config.model_config, +# cache_config=vllm_config.cache_config, +# quant_config=vllm_config.quant_config, +# ) +# for idx in range(self.mtp_start_layer_idx, +# self.mtp_start_layer_idx + self.num_mtp_layers) +# }) +# self.embed_tokens = VocabParallelEmbedding( +# config.vocab_size, +# config.hidden_size, +# ) +# self.logits_processor = LogitsProcessor(config.vocab_size) + +# def forward( +# self, +# input_ids: torch.Tensor, +# positions: torch.Tensor, +# previous_hidden_states: torch.Tensor, +# inputs_embeds: Optional[torch.Tensor] = None, +# spec_step_idx: int = 0, +# ) -> torch.Tensor: +# if inputs_embeds is None: +# inputs_embeds = self.embed_tokens(input_ids) +# current_step_idx = (spec_step_idx % self.num_mtp_layers) +# return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( +# input_ids, +# positions, +# previous_hidden_states, +# inputs_embeds, +# current_step_idx, +# ) + +# def compute_logits( +# self, +# hidden_states: torch.Tensor, +# sampling_metadata: SamplingMetadata, +# spec_step_idx: int = 0, +# ) -> torch.Tensor: +# current_step_idx = (spec_step_idx % self.num_mtp_layers) +# mtp_layer = self.layers[str(self.mtp_start_layer_idx + +# current_step_idx)] +# logits = self.logits_processor(mtp_layer.shared_head.head, +# mtp_layer.shared_head(hidden_states), +# sampling_metadata) +# return logits + +# @support_torch_compile +# class DeepSeekMTP(nn.Module, SupportsPP): + +# def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): +# super().__init__() +# self.config = vllm_config.model_config.hf_config +# quant_config = vllm_config.quant_config + +# self.quant_method = None +# if quant_config is not None: +# self.quant_method = quant_config.get_name() +# os.environ['LLAMA_NN'] = '0' +# os.environ['LM_NN'] = '0' +# # The AWQ layer of MTP uses BlockInt8W8A8. +# if self.quant_method == "moe_wna16" or self.quant_method == "awq_marlin": +# vllm_config.quant_config = BlockInt8Config(is_checkpoint_int8_serialized=True, weight_block_size=[128,128]) + +# self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config, +# prefix=maybe_prefix( +# prefix, "model")) +# self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' + + +# def forward( +# self, +# input_ids: torch.Tensor, +# positions: torch.Tensor, +# previous_hidden_states: torch.Tensor, +# intermediate_tensors: Optional[IntermediateTensors] = None, +# inputs_embeds: Optional[torch.Tensor] = None, +# spec_step_idx: int = 0, +# ) -> torch.Tensor: +# hidden_states = self.model(input_ids, positions, +# previous_hidden_states, inputs_embeds, +# spec_step_idx) +# return hidden_states + +# def compute_logits( +# self, +# hidden_states: torch.Tensor, +# sampling_metadata: SamplingMetadata, +# spec_step_idx: int = 0, +# ) -> Optional[torch.Tensor]: +# return self.model.compute_logits(hidden_states, sampling_metadata, +# spec_step_idx) + +# def load_weights(self, weights: Iterable[tuple[str, +# torch.Tensor]]) -> set[str]: +# stacked_params_mapping = [ +# ("gate_up_proj", "gate_proj", 0), +# ("gate_up_proj", "up_proj", 1), +# ] + +# expert_params_mapping = FusedMoE.make_expert_params_mapping( +# ckpt_gate_proj_name="gate_proj", +# ckpt_down_proj_name="down_proj", +# ckpt_up_proj_name="up_proj", +# num_experts=self.config.n_routed_experts) + +# params_dict = dict(self.named_parameters()) +# loaded_params: set[str] = set() +# for name, loaded_weight in weights: +# if "rotary_emb.inv_freq" in name: +# continue +# spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) +# if spec_layer is None: +# continue +# name = self._rewrite_spec_layer_name(spec_layer, name) +# for (param_name, weight_name, shard_id) in stacked_params_mapping: +# # Skip non-stacked layers and experts (experts handled below). +# if weight_name not in name: +# continue +# # We have mlp.experts[0].gate_proj in the checkpoint. +# # Since we handle the experts below in expert_params_mapping, +# # we need to skip here BEFORE we update the name, otherwise +# # name will be updated to mlp.experts[0].gate_up_proj, which +# # will then be updated below in expert_params_mapping +# # for mlp.experts[0].gate_gate_up_proj, which breaks load. +# if (("mlp.experts." in name) and name not in params_dict): +# continue +# name = name.replace(weight_name, param_name) +# # Skip loading extra bias for GPTQ models. +# if name.endswith(".bias") and name not in params_dict: +# continue + +# param = params_dict[name] +# weight_loader = param.weight_loader +# weight_loader(param, loaded_weight, shard_id) +# break +# else: +# for mapping in expert_params_mapping: +# param_name, weight_name, expert_id, shard_id = mapping +# if weight_name not in name: +# continue +# name = name.replace(weight_name, param_name) + +# param = params_dict[name] +# weight_loader = param.weight_loader +# weight_loader(param, +# loaded_weight, +# name, +# shard_id=shard_id, +# expert_id=expert_id) +# break +# else: +# # Skip loading extra bias for GPTQ models. +# if name.endswith(".bias") and name not in params_dict: +# continue + +# # According to DeepSeek-V3 Technical Report, MTP modules +# # shares embedding layer. We only load the first weights. +# if (spec_layer != self.model.mtp_start_layer_idx +# and ".layers" not in name): +# continue + +# param = params_dict[name] +# weight_loader = getattr(param, "weight_loader", +# default_weight_loader) +# weight_loader(param, loaded_weight) +# loaded_params.add(name) + +# if self.use_llama_nn and self.quant_method is None: +# lay_key_words = [ +# "self_attn.eh_proj.weight", +# "self_attn.q_proj.weight", +# "self_attn.q_a_proj.weight", +# "self_attn.q_b_proj.weight", +# "self_attn.kv_a_proj_with_mqa.weight", +# "self_attn.kv_b_proj.weight", +# "self_attn.o_proj.weight", +# "mlp.gate_up_proj.weight", +# "mlp.down_proj.weight", +# "mlp.gate.weight", +# "shared_experts.gate_up_proj.weight", +# "shared_experts.down_proj.weight", +# "shared_head.head.weight", +# ] + +# combined_words = "|".join(lay_key_words) + +# for layername in loaded_params: +# weight = params_dict[layername] +# matches = re.findall(combined_words, layername) +# if matches: +# _weight = torch.zeros_like(weight.data) +# ori_shape =_weight.shape + +# ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1]) +# weight.data.copy_(_weight) + +# weight.data=weight.data.reshape(ori_shape[1],-1) + +# return loaded_params + +# def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: +# """ +# Rewrite the weight name to match the format of the original model. +# Add .mtp_block for modules in transformer layer block for spec layer +# and rename shared layer weights to be top level. +# """ +# spec_layer_weight_names = [ +# "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" +# ] +# shared_weight_names = ["embed_tokens"] +# spec_layer_weight = False +# shared_weight = False +# for weight_name in spec_layer_weight_names: +# if weight_name in name: +# spec_layer_weight = True +# if weight_name in shared_weight_names: +# shared_weight = True +# break +# if not spec_layer_weight: +# # treat rest weights as weights for transformer layer block +# name = name.replace(f"model.layers.{spec_layer}.", +# f"model.layers.{spec_layer}.mtp_block.") +# elif shared_weight: +# # treat shared weights as top level weights +# name = name.replace(f"model.layers.{spec_layer}.", "model.") +# return name + diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py new file mode 100644 index 0000000..8298e74 --- /dev/null +++ b/vllm/model_executor/models/deepseek_v2.py @@ -0,0 +1,1128 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only DeepseekV2/DeepseekV3 model.""" +import os +import re +import vllm.envs as envs + +import typing +from collections.abc import Callable, Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import (CacheConfig, ModelConfig, VllmConfig, + get_current_vllm_config) +from vllm.distributed import (get_ep_group, get_pp_group, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import MixtureOfExperts, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) +from vllm import _custom_ops as ops +from vllm.utils import W8a8GetCacheJSON + +os.environ['DPSK_FP16_QUICK'] = os.environ.get('DPSK_FP16_QUICK', '0') +class DeepseekV2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x, + rms_weight: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None, + update_hd: Optional[bool] = False + ): + if envs.USE_FUSED_RMS_QUANT: + gate_up, new_resi, _ = self.gate_up_proj(x, rms_weight, residual, update_hd=update_hd) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x, new_resi + else: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DeepseekV2MoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts: int = config.n_routed_experts + self.n_shared_experts: int = config.n_shared_experts + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + + self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts)) + else: + self.gate.e_score_correction_bias = None + + # Load balancing settings. + vllm_config = get_current_vllm_config() + parallel_config = vllm_config.parallel_config + self.enable_eplb = enable_eplb + self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1' + + self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_logical_experts = self.n_routed_experts + self.n_physical_experts = (self.n_logical_experts + + self.n_redundant_experts) + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = (self.ep_rank * + self.n_local_physical_experts) + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + routed_scaling_factor=self.routed_scaling_factor) + + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=self.experts.must_reduce_shared_expert_outputs( + ), + prefix=f"{prefix}.shared_experts", + ) + from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce + self.tbo_all_reduce = tbo_all_reduce + + def forward(self, hidden_states: torch.Tensor, + rms_weight: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None + ) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + if self.n_shared_experts is not None: + if envs.USE_FUSED_RMS_QUANT: + shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True) + else: + shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick: + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits) * self.routed_scaling_factor + else: + # Fix FP16 overflow + # See DeepseekV2DecoderLayer for more details. + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + + if shared_output is not None: + if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick: + final_hidden_states = final_hidden_states + shared_output + else: + # Fix FP16 overflow + # See DeepseekV2DecoderLayer for more details. + final_hidden_states = final_hidden_states + shared_output \ + * (1. / self.routed_scaling_factor) + + if self.tp_size > 1: + if envs.VLLM_ENABLE_TBO: + final_hidden_states = self.tbo_all_reduce(final_hidden_states) + else: + final_hidden_states = ( + self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states)) + if envs.USE_FUSED_RMS_QUANT: + return final_hidden_states.view(num_tokens, hidden_dim), new_resi + else: + return final_hidden_states.view(num_tokens, hidden_dim) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + import math + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekV2Attention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int, + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") + else: + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") + # O projection. + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.attn = Attention(self.num_local_heads, + self.qk_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, + self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, + self.qk_head_dim) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, _ = latent_cache.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_cache = latent_cache.unsqueeze(1) + kv_a = self.kv_a_layernorm(kv_a.contiguous()) + kv = self.kv_b_proj(kv_a)[0] + kv = kv.view(-1, self.num_local_heads, + self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pe = latent_cache[:, :, self.kv_lora_rank:] + + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + + q[..., self.qk_nope_head_dim:] = q_pe + k = torch.empty_like(q) + k[..., :self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim:] = k_pe + # padding value to qk_head_dim for alignment + v = torch.nn.functional.pad( + v, [0, self.qk_head_dim - self.v_head_dim], + value=0).view(-1, self.num_local_heads * self.qk_head_dim) + attn_output = self.attn(q, k, v) + attn_output = attn_output.view( + -1, self.num_local_heads, + self.qk_head_dim)[..., :self.v_head_dim].reshape( + -1, self.num_local_heads * self.v_head_dim) + output, _ = self.o_proj(attn_output) + return output + + +class DeepseekV2MLAAttention(nn.Module): + """ + Main reference: DeepseekV2 paper, and FlashInfer Implementation + (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). + + For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py + """ + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + if envs.USE_FUSED_RMS_QUANT: + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + eps=config.rms_norm_eps, + prefix=f"{prefix}.q_a_proj") + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + eps=config.rms_norm_eps, + prefix=f"{prefix}.q_b_proj") + else: + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") + + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + + else: + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + # In the MLA backend, kv_cache includes both k_c and + # pe (i.e. decoupled position embeddings). In particular, + # the concat_and_cache_mla op requires + # k_c.size(1) + k_pe.size(1) == kv_cache.size(2) + # i.e. + # kv_lora_rank + qk_rope_head_dim == head_size + self.mla_attn = Attention( + num_heads=self.num_local_heads, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + scale=self.scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + kv_b_proj=self.kv_b_proj, + ) + + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + rms_weight: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if envs.USE_FUSED_RMS_QUANT and rms_weight is not None: + if self.q_lora_rank is not None: + q_c, new_residual, _, input_quant_args = self.q_a_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False) + q, _, _ = self.q_b_proj(q_c, rms_weight=self.q_a_layernorm.weight.data, residual=None, update_hd=False) + + else: + q = self.q_proj(hidden_states)[0] + kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0].split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + + q = q.view(-1, self.num_local_heads, self.qk_head_dim) + # Add head dim of 1 to k_pe + k_pe = k_pe.unsqueeze(1) + + q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( + positions, q[..., self.qk_nope_head_dim:], k_pe) + + attn_out = self.mla_attn( + q, + kv_c_normed, + k_pe, + output_shape=(hidden_states.shape[0], + self.num_local_heads * self.v_head_dim)) + return self.o_proj(attn_out)[0], new_residual + else: + if self.q_lora_rank is not None: + q_c = self.q_a_proj(hidden_states)[0] + q_c = self.q_a_layernorm(q_c) + q = self.q_b_proj(q_c)[0] + else: + q = self.q_proj(hidden_states)[0] + kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + + q = q.view(-1, self.num_local_heads, self.qk_head_dim) + # Add head dim of 1 to k_pe + k_pe = k_pe.unsqueeze(1) + + q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( + positions, q[..., self.qk_nope_head_dim:], k_pe) + + attn_out = self.mla_attn( + q, + kv_c_normed, + k_pe, + output_shape=(hidden_states.shape[0], + self.num_local_heads * self.v_head_dim)) + return self.o_proj(attn_out)[0] + + +class DeepseekV2DecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + enable_eplb: bool = False, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. + layer_idx = int(prefix.split(sep='.')[-1]) + self.layer_idx = layer_idx + if model_config.use_mla: + attn_cls = DeepseekV2MLAAttention + else: + attn_cls = DeepseekV2Attention + self.self_attn = attn_cls( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank + if hasattr(config, "q_lora_rank") else None, + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.dpsk_fp16_quick = os.environ.get('DPSK_FP16_QUICK') == '1' + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = DeepseekV2MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + enable_eplb=enable_eplb, + ) + else: + self.mlp = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.routed_scaling_factor = config.routed_scaling_factor + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + if envs.USE_FUSED_RMS_QUANT: + # Fix residual FP16 overflow + residual_fix_overflow = False + + assert self.input_layernorm.has_weight is True + if residual is None: + residual = hidden_states + hidden_states, _ = self.self_attn( + positions = positions, + hidden_states = hidden_states, + rms_weight = self.input_layernorm.weight.data, + residual = None + ) + residual_fix_overflow = True + else: + hidden_states, new_residual = self.self_attn( + positions = positions, + hidden_states = hidden_states, + rms_weight = self.input_layernorm.weight.data, + residual = residual + ) + residual = new_residual + + if hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick: + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1. / self.routed_scaling_factor + if self.layer_idx == 0 or residual_fix_overflow: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1. / self.routed_scaling_factor + + hidden_states, new_resi = self.mlp(hidden_states, self.post_attention_layernorm.weight.data, residual) + + if isinstance(self.mlp, + DeepseekV2MLP) and hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. + # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE + hidden_states *= 1. / self.routed_scaling_factor + return hidden_states, new_resi + + else: + # Self Attention + # Fix residual FP16 overflow + residual_fix_overflow = False + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + residual_fix_overflow = True + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + if hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick: + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1. / self.routed_scaling_factor + if self.layer_idx == 0 or residual_fix_overflow: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1. / self.routed_scaling_factor + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + if isinstance(self.mlp, + DeepseekV2MLP) and hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. + # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE + hidden_states *= 1. / self.routed_scaling_factor + + return hidden_states, residual + + +@support_torch_compile +class DeepseekV2Model(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + enable_eplb = vllm_config.parallel_config.enable_eplb + self.config = config + + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: DeepseekV2DecoderLayer( + config, + prefix, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + enable_eplb=enable_eplb, + ), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.quant_method = None + if quant_config is not None: + self.quant_method = quant_config.get_name() + os.environ['LLAMA_NN'] = '0' + os.environ['LM_NN'] = '0' + + self.use_w4a16_moe_sz = os.environ.get('AWQ_MOE_SZ') == '1' + self.config = config + self.quant_config = quant_config + + self.model = DeepseekV2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + self.expert_weights = [] + + # Set MoE hyperparameters + self.num_moe_layers = (config.num_hidden_layers - + config.first_k_dense_replace) + self.num_expert_groups = config.n_group + + self.moe_layers: list[FusedMoE] = [] + example_moe = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, DeepseekV2DecoderLayer) + if isinstance(layer.mlp, DeepseekV2MoE): + example_moe = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + # Pick last one layer since the first ones may be dense layers. + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' + self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' + self.tritonsingleton= W8a8GetCacheJSON() + self.tritonsingleton.topk = config.num_experts_per_tok + self.tritonsingleton.quant_method=self.quant_method + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. + self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def restore_qzeros_tensor(self, qzeros, qscales): + + low_bits = qzeros & 0x0F + high_bits = qzeros >> 4 + + zeors_tensor = torch.stack([low_bits, high_bits], dim=2).view(qzeros.shape[0], -1 , qzeros.shape[-1]) + zeors_int16 = zeors_tensor.to(torch.int16) + assert zeors_int16.shape == qscales.shape + + uint16_tensor1 = zeors_int16.view(torch.uint16) + uint16_tensor2 = qscales.view(torch.uint16) + + uint32_tensor1 = uint16_tensor1.to(torch.int32) << 16 + uint32_tensor2 = uint16_tensor2.to(torch.int32) + + result_tensor = uint32_tensor1 + uint32_tensor2 + result_tensor =result_tensor.view(torch.uint32) + result_tensor = result_tensor.transpose(1, 2).contiguous() + return result_tensor + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts, + num_redundant_experts=self.num_redundant_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. + weight_loader = typing.cast(Callable[..., bool], + param.weight_loader) + success = weight_loader(param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True) + if success: + name = name_mapped + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + if self.use_llama_nn and self.quant_method is None: + lay_key_words = [ + "self_attn.q_proj.weight", + "self_attn.q_a_proj.weight", + "self_attn.q_b_proj.weight", + "self_attn.kv_a_proj_with_mqa.weight", + "self_attn.kv_b_proj.weight", + "self_attn.o_proj.weight", + "mlp.gate_up_proj.weight", + "mlp.down_proj.weight", + "mlp.gate.weight", + "shared_experts.gate_up_proj.weight", + "shared_experts.down_proj.weight", + "lm_head.weight", + ] + + combined_words = "|".join(lay_key_words) + + for layername in loaded_params: + weight = params_dict[layername] + matches = re.findall(combined_words, layername) + if matches: + _weight = torch.zeros_like(weight.data) + ori_shape =_weight.shape + + ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1]) + weight.data.copy_(_weight) + + weight.data=weight.data.reshape(ori_shape[1],-1) + + return loaded_params + + +class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): + pass + + +def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, + weight_name: str) -> Optional[int]: + if hasattr(config, + "num_nextn_predict_layers") and (config.num_nextn_predict_layers + > 0): + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + if weight_name.startswith(f"model.layers.{layer_idx+i}."): + return layer_idx + i + return None diff --git a/vllm/model_executor/models/deepseek_v3.py b/vllm/model_executor/models/deepseek_v3.py new file mode 100644 index 0000000..1c240ae --- /dev/null +++ b/vllm/model_executor/models/deepseek_v3.py @@ -0,0 +1,850 @@ +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only DeepseekV3 model.""" +import os +import re +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +import vllm.envs as envs +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, VllmConfig, ParallelConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) +from vllm import _custom_ops as ops + + +class DeepseekV3MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DeepseekV3MoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + self.n_shared_experts = config.n_shared_experts + self.routed_scaling_factor = config.routed_scaling_factor + if self.tp_size > config.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.n_routed_experts}.") + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + + self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts)) + else: + self.gate.e_score_correction_bias = None + + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias,) + + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = DeepseekV3MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + ) + from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce + self.tbo_all_reduce = tbo_all_reduce + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits) * self.routed_scaling_factor + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + if envs.VLLM_ENABLE_TBO: + final_hidden_states = self.tbo_all_reduce(final_hidden_states) + else: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_dim) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + import math + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekV3Attention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int, + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") + else: + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") + # O projection. + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + self.use_normal_rope = False + else: + self.use_normal_rope = True + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.attn = Attention(self.num_local_heads, + self.qk_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, + self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, + self.qk_head_dim) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, _ = latent_cache.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_cache = latent_cache.unsqueeze(1) + kv_a = self.kv_a_layernorm(kv_a.contiguous()) + kv = self.kv_b_proj(kv_a)[0] + kv = kv.view(-1, self.num_local_heads, + self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pe = latent_cache[:, :, self.kv_lora_rank:] + + if self.use_normal_rope: + seq_len = positions.size(0) + ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape + q_pe = q_pe.reshape(seq_len, -1) + k_pe = k_pe.reshape(seq_len, -1) + + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + + if self.use_normal_rope: + q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape) + + q[..., self.qk_nope_head_dim:] = q_pe + k = torch.empty_like(q) + k[..., :self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim:] = k_pe + # padding value to qk_head_dim for alignment + v = torch.nn.functional.pad( + v, [0, self.qk_head_dim - self.v_head_dim], + value=0).view(-1, self.num_local_heads * self.qk_head_dim) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = attn_output.view( + -1, self.num_local_heads, + self.qk_head_dim)[..., :self.v_head_dim].reshape( + -1, self.num_local_heads * self.v_head_dim) + output, _ = self.o_proj(attn_output) + return output + + +class DeepseekV3MLAAttention(nn.Module): + """ + Main reference: DeepseekV2 paper, and FlashInfer Implementation + (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). + + For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py + """ + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") + else: + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.mla_attn = Attention( + num_heads=self.num_local_heads, + head_size=self.kv_lora_rank, + scale=self.scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + rotary_emb=self.rotary_emb, + q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, + kv_b_proj=self.kv_b_proj, + o_proj=self.o_proj, + ) + + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + ckq = self.q_a_proj(hidden_states)[0] + hidden_states_or_q_c = self.q_a_layernorm(ckq) + else: + hidden_states_or_q_c = hidden_states + kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache, + attn_metadata) + + +class DeepseekV3DecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. + layer_idx = int(prefix.split(sep='.')[-1]) + if model_config.use_mla: + attn_cls = DeepseekV3MLAAttention + else: + attn_cls = DeepseekV3Attention + self.self_attn = attn_cls( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank + if hasattr(config, "q_lora_rank") else None, + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = DeepseekV3MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = DeepseekV3MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +# TODO(simon): check whether we support torch compile for Deepseek V3 +# @support_torch_compile +class DeepseekV3Model(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: DeepseekV3DecoderLayer( + config, + prefix, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + ), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class DeepseekV3ForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.parallel_config = vllm_config.parallel_config + + self.model = DeepseekV3Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"),) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + self.quant_method = None + if quant_config is not None: + self.quant_method=quant_config.get_name() + self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' + + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + # TODO(simon): support nextn predict layers + if hasattr(self.config, "num_nextn_predict_layers" + ) and self.config.num_nextn_predict_layers > 0: + assert self.config.num_nextn_predict_layers == 1 + layer_idx = self.config.num_hidden_layers + if name.startswith(f"model.layers.{layer_idx}"): + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + if self.use_llama_nn and self.quant_method is None: + lay_key_words = [ + "self_attn.q_a_proj.weight", + "self_attn.kv_a_proj_with_mqa.weight", + "mlp.gate.weight", + "mlp.gate_up_proj.weight", + "mlp.down_proj", + "shared_experts.gate_up_proj", + "shared_experts.down_proj", + "self_attn.q_proj.weight", + "self_attn.q_b_proj.weight", + "self_attn.kv_b_proj.weight", + "self_attn.o_proj.weight", + "lm_head.weight" + ] + + combined_words = "|".join(lay_key_words) + + for layername in loaded_params: + weight = params_dict[layername] + matches = re.findall(combined_words, layername) + if matches: + _weight = torch.zeros_like(weight.data) + ori_shape =_weight.shape + + ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1]) + weight.data.copy_(_weight) + + weight.data=weight.data.reshape(ori_shape[1],-1) + + return loaded_params diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py new file mode 100644 index 0000000..a9654f5 --- /dev/null +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -0,0 +1,660 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py +"""Inference-only Deepseek-VL2 model compatible with HuggingFace weights.""" +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Literal, Optional, TypedDict, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from transformers import BatchFeature + +from vllm.config import VllmConfig +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) +from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, + ImageSize, MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, MultiModalHashes, + PromptReplacement, PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config, + MlpProjectorConfig, + VisionEncoderConfig) +from vllm.transformers_utils.processors.deepseek_vl2 import ( + DeepseekVLV2Processor) +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from vllm.utils import is_list_of + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) + +# The image token id may be various +_IMAGE_TOKEN = "" + + +class DeepseekVL2ImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: Union[torch.Tensor, list[torch.Tensor]] + """ + Shape: `(batch_size * num_images, num_channels, height, width)` + """ + images_spatial_crop: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` + """ + + +class DeepseekVL2VImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: Union[torch.Tensor, list[torch.Tensor]] + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + """ + + +DeepseekVL2ImageInputs = Union[DeepseekVL2ImagePixelInputs, + DeepseekVL2VImageEmbeddingInputs] + + +class MlpProjector(nn.Module): + + def __init__(self, cfg: MlpProjectorConfig): + + super().__init__() + + self.cfg = cfg + assert not cfg.token_pooling, ( + "Token pooling is not supported currently.") + + if cfg.projector_type == "downsample_mlp_gelu": + mlp_depth = cfg.depth + mlp_ratio = cfg.mlp_ratio + modules = [ + nn.Linear( + cfg.input_dim * cfg.downsample_ratio * + cfg.downsample_ratio, cfg.n_embed * mlp_ratio) + ] + for _ in range(1, mlp_depth - 1): + modules.append(nn.GELU()) + modules.append( + nn.Linear(cfg.n_embed * mlp_ratio, + cfg.n_embed * mlp_ratio)) + modules.append(nn.GELU()) + modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed)) + modules = nn.Sequential(*modules) + + else: + raise NotImplementedError( + f"Unsupported projector type: {cfg.projector_type}") + + self.layers = modules + + def forward(self, x): + bs, hw, input_dim = x.shape + h = w = int((hw)**0.5) + """compute padding""" + if h % self.cfg.downsample_ratio: + pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio + else: + pad = 0 + x = x.reshape(bs, h, w, input_dim) + if pad > 0: + x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0) + """4 to 1 concat""" + x = x.permute(0, 3, 1, 2) # B, C, H, W + x = F.unfold(x, + kernel_size=self.cfg.downsample_ratio, + stride=self.cfg.downsample_ratio, + padding=0) # B, C*4, HW // 4 + x = x.permute(0, 2, 1) + + return self.layers(x) + + +class DeepseekVL2ProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(DeepseekVLV2Config) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(DeepseekVLV2Processor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_num_image_tokens(self, + *, + image_width: int, + image_height: int, + cropping: bool = True) -> int: + hf_processor = self.get_hf_processor() + image_size = hf_processor.image_size + patch_size = hf_processor.patch_size + downsample_ratio = hf_processor.downsample_ratio + + if cropping: + best_width, best_height = hf_processor.select_best_resolution( + (image_width, image_height)) + num_width_tiles, num_height_tiles = (best_width // image_size, + best_height // image_size) + else: + num_width_tiles = num_height_tiles = 1 + + h = w = math.ceil((image_size // patch_size) / downsample_ratio) + + global_views_tokens = h * (w + 1) + local_views_tokens = (num_height_tiles * h) * (num_width_tiles * w + 1) + return global_views_tokens + local_views_tokens + 1 + + def get_image_size_with_most_features(self) -> ImageSize: + hf_config = self.get_hf_config() + candidate_resolutions = hf_config.candidate_resolutions + height, width = max(candidate_resolutions, + key=lambda x: self.get_num_image_tokens( + image_width=x[1], image_height=x[0])) + return ImageSize(width=width, height=height) + + +class DeepseekVL2DummyInputsBuilder( + BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + max_image_size = self.info.get_image_size_with_most_features() + + return { + "image": + self._get_dummy_images(width=max_image_size.width, + height=max_image_size.height, + num_images=num_images) + } + + +class DeepseekVL2MultiModalProcessor( + BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + if mm_data: + processed_outputs = self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), + dict(prompt=prompt, **mm_data), + dict(**mm_kwargs, **tok_kwargs), + ) + pixel_values = processed_outputs["pixel_values"] + # split pixel values into patches corresponding to each image + images_spatial_crop = processed_outputs["images_spatial_crop"] + patches_per_image = [ + x.prod().item() + 1 for x in images_spatial_crop + ] + pixel_values = pixel_values.split(patches_per_image) + processed_outputs["pixel_values"] = pixel_values + else: + tokenizer = self.info.get_tokenizer() + processed_outputs = tokenizer(prompt, + add_special_tokens=True, + return_tensors="pt") + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + images_spatial_crop=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + image_token_id = hf_processor.image_token_id + assert isinstance(image_token_id, int) + + def get_replacement_deepseek_vl2(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + cropping=len(images) <= 2, + ) + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], + replacement=get_replacement_deepseek_vl2, + ) + ] + + def _cached_apply_hf_processor( + self, + prompt: Union[str, list[int]], + mm_data_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + *, + return_mm_hashes: bool, + ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: + # The processor logic is different for len(images) <= 2 vs > 2 + # Since the processing cache assumes that the processor output is + # invariant of how many images are passed per prompt, we only + # perform caching for the most common case + if mm_data_items.get_count("image", strict=False) > 2: + return self._apply_hf_processor( + prompt=prompt, + mm_data_items=mm_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + return_mm_hashes=return_mm_hashes, + ) + + return super()._cached_apply_hf_processor( + prompt=prompt, + mm_data_items=mm_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + return_mm_hashes=return_mm_hashes, + ) + + +@MULTIMODAL_REGISTRY.register_processor( + DeepseekVL2MultiModalProcessor, + info=DeepseekVL2ProcessingInfo, + dummy_inputs=DeepseekVL2DummyInputsBuilder) +class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): + + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ + "language.": "language_model.", + }) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "" + + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: DeepseekVLV2Config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.vision_config = config.vision_config + self.projector_config = config.projector_config + self.text_config = config.text_config + + model_config = vllm_config.model_config + tokenizer = cached_tokenizer_from_config(model_config) + self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN] + + self.vision = self._init_vision_module(self.vision_config, + quant_config, + maybe_prefix(prefix, "vision")) + + self.projector = MlpProjector(self.projector_config) + self.tile_tag = config.tile_tag + self.global_view_pos = config.global_view_pos + + # special token for image token sequence format + embed_std = 1 / torch.sqrt( + torch.tensor(self.projector_config.n_embed, dtype=torch.float32)) + if self.tile_tag == "2D": + # <|view_separator|>, <|\n|> + self.image_newline = nn.Parameter( + torch.randn(self.projector_config.n_embed) * embed_std) + # This is a typo in original implementation + self.view_separator = nn.Parameter( + torch.randn(self.projector_config.n_embed) * embed_std) + else: + raise ValueError( + f"Only 2D tile_tag is supported currently, got: {self.tile_tag}" + ) + + if self.text_config.topk_method == "noaux_tc": + architectures = ["DeepseekV3ForCausalLM"] + elif not self.text_config.use_mla: + architectures = ["DeepseekForCausalLM"] + else: + architectures = ["DeepseekV2ForCausalLM"] + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=self.text_config, + prefix=maybe_prefix(prefix, "language"), + architectures=architectures, + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def _init_vision_module( + self, + vision_config: VisionEncoderConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + ) -> nn.Module: + # TODO: refactor vision model through timm wrapper from transformers + try: + import timm + except ImportError: + raise ImportError("Please install timm") from ImportError + + with set_default_torch_dtype(torch.float16): + model = timm.create_model( + "vit_so400m_patch14_siglip_384.webli", + pretrained=False, + num_classes=0, + dynamic_img_size=True, + dynamic_img_pad=True, + ) + + model = model.to(dtype=torch.get_default_dtype()) + return model + + def _validate_pixel_values( + self, data: Union[torch.Tensor, list[torch.Tensor]] + ) -> Union[torch.Tensor, list[torch.Tensor]]: + + h = w = self.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("num_patches", *map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _validate_images_spatial_crop( + self, data: Union[torch.Tensor, list[torch.Tensor]] + ) -> Union[torch.Tensor, list[torch.Tensor]]: + expected_dims = 2 + + def _validate_shape(d: torch.Tensor): + actual_dims = d.size(-1) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + f"The expected shape of image sizes per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[DeepseekVL2ImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + images_spatial_crop = kwargs.pop("images_spatial_crop", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(images_spatial_crop, (torch.Tensor, list)): + raise ValueError("Incorrect type of image sizes. " + f"Got type: {type(images_spatial_crop)}") + + return DeepseekVL2ImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values(flatten_bn(pixel_values)), + images_spatial_crop=self._validate_images_spatial_crop( + flatten_bn(images_spatial_crop, concat=True))) + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + + return DeepseekVL2VImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds), + ) + + raise AssertionError("This line should be unreachable.") + + def _pixel_values_to_embedding( + self, + pixel_values: NestedTensors, + images_spatial_crop: torch.Tensor, + ) -> NestedTensors: + # Pixel_values: n_image * batch_size * [patch_per_img, 3, height, width] + total_tiles = [x for x in pixel_values] + + # [batch_all_tiles, 3, height, width] + total_tiles = torch.cat(total_tiles, dim=0) + + # [batch_all_tiles, vit_seq_len, c] + images_feature = self.vision.forward_features(total_tiles) + + # [batch_all_tiles, hw, D] + images_embeds = self.projector(images_feature) + + _, hw, n_dim = images_embeds.shape + h = w = int(hw**0.5) + + # fill image token based on self.tile_tag & self.global_view_pos + tile_index = 0 + vision_embeddings = [] + for jdx in range(images_spatial_crop.size(0)): + # extra global & local features + num_width_tiles, num_height_tiles = images_spatial_crop[jdx] + if num_width_tiles == 0 or num_height_tiles == 0: + break + num_tiles_in_image = num_width_tiles * num_height_tiles + + # [hw, D] + global_features = images_embeds[tile_index] + + # [num_height_tiles * num_width_tiles, hw, D] + local_features = images_embeds[tile_index + 1:tile_index + 1 + + num_tiles_in_image] + tile_index += num_tiles_in_image + 1 + + # format global and local features + # ----------------- global view add newline ----------------- + # [hw, D] -> [h, w, D] + global_features = global_features.view(h, w, n_dim) + + # [D] -> [h, 1, D] + new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h) + + # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D] + global_features = torch.cat([global_features, new_lines_in_global], + dim=1) + + # [h, w + 1, D] -> [h * (w + 1), D] + global_features = global_features.view(-1, n_dim) + + # ----------------- local view add newline ----------------- + # [num_height_tiles * num_width_tiles, h * w, D] -> + # [num_height_tiles * h, num_width_tiles * w, D] + local_features = rearrange(local_features, + "(th tw) (h w) d -> (th h) (tw w) d", + th=num_height_tiles, + tw=num_width_tiles, + h=h, + w=w) + + # [D] -> [num_height_tiles * h, 1, D] + new_lines_in_local = repeat(self.image_newline, + "d -> (th h) 1 d", + th=num_height_tiles, + h=h) + + # [num_height_tiles * h, num_width_tiles * w + 1, D] + local_features = torch.cat([local_features, new_lines_in_local], + dim=1) + + # [num_height_tiles * h, num_width_tiles * w + 1, D] + # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D] + local_features = local_features.view(-1, n_dim) + + # merge global and local tiles + if self.global_view_pos == "head": + global_local_features = torch.cat([ + global_features, + self.view_separator[None, :], + local_features, + ]) + else: + global_local_features = torch.cat([ + local_features, + self.view_separator[None, :], + global_features, + ]) + + vision_embeddings.append(global_local_features) + return vision_embeddings + + def _process_image_input( + self, image_input: DeepseekVL2ImageInputs) -> torch.Tensor: + if image_input["type"] == "image_embeds": + image_data = image_input["data"] + if is_list_of(image_data, torch.Tensor): + # it's already a list of tensors + return image_data + if len(image_data.shape) == 3: + # 3D tensor + return list(torch.unbind(image_data, dim=0)) + raise ValueError( + "We expect batched 2D tensors; " + "this can be either a list of 2D tensors or a single 3D tensor." + ) + + pixel_values = image_input["data"] + images_spatial_crop = image_input["images_spatial_crop"] + + return self._pixel_values_to_embedding( + pixel_values=pixel_values, images_spatial_crop=images_spatial_crop) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None \ + and len(multimodal_embeddings) != 0: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.image_token_id) + return inputs_embeds + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object): + + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + + loader = AutoWeightsLoader(self) + autoloaded_weights = loader.load_weights(weights, + mapper=self.hf_to_vllm_mapper) + return autoloaded_weights diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py new file mode 100644 index 0000000..4bdcbfa --- /dev/null +++ b/vllm/model_executor/models/dots1.py @@ -0,0 +1,536 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2025 The rednote-hilab team. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only dots1 model.""" +from collections.abc import Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class Dots1MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Dots1MoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + self.n_shared_experts = config.n_shared_experts + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + + self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = (nn.Parameter( + torch.empty(config.n_routed_experts))) + else: + self.gate.e_score_correction_bias = None + + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias) + + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = Dots1MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits) * self.routed_scaling_factor + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_dim) + + +class Dots1Attention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + config: PretrainedConfig, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = getattr(config, "head_dim", + hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + attention_bias = config.attention_bias + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=attention_bias, + quant_config=quant_config, + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward(self, positions: torch.Tensor, + hidden_states: torch.Tensor) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = self.q_norm(q.reshape(-1, self.num_heads, + self.head_dim)).reshape(q.shape) + k = self.k_norm(k.reshape(-1, self.num_kv_heads, + self.head_dim)).reshape(k.shape) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Dots1DecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + layer_idx = int(prefix.split(sep='.')[-1]) + self.layer_idx = layer_idx + + self.self_attn = Dots1Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + config=config, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = Dots1MoE(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Dots1MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.routed_scaling_factor = config.routed_scaling_factor + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class Dots1Model(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config + + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Dots1DecoderLayer( + config, + prefix, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + ), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +@support_torch_compile +class Dots1ForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Dots1Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py new file mode 100644 index 0000000..c551ecd --- /dev/null +++ b/vllm/model_executor/models/eagle.py @@ -0,0 +1,261 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .utils import maybe_prefix + +logger = init_logger(__name__) + + +class DummyInputLayerNorm(nn.Module): + + def __init__(self, weight=None, bias=None): + super().__init__() + self.weight = nn.Parameter(weight) if weight is not None else None + self.bias = nn.Parameter(bias) if bias is not None else None + + def forward(self, x): + return x + + +class DummyOutputNorm(nn.Module): + + def forward(self, x, residual): + if residual is None: + return x + else: + return x + residual, None + + +class EAGLE(nn.Module): + """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077 + Reference implementation: https://github.com/SafeAILab/EAGLE + + Differences from reference implementation: + 1. In reference, LlamaDecoderLayer implementation doesn't have + input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427). + Following this approach, our implementation also disables + the input_layernorm for the first decoder layer. + 2. We allow any decoder layer to be used in EAGLE whereas in reference + decoder layer is fixed to be LlamaDecoderLayer. + 3. We have an optional token_map which reduces draft vocab to most + frequently used tokens to give some additional speed-up by reducing + sampling overhead. This is disabled unless the checkpoint file has + explicit token_map tensor and config has an optional attribute + truncated_vocab_size < vocab_size. To use this technique, one has to find + the top-k most frequent tokens in target dataset and add that as a tensor + in the draft checkpoint (using key token_map). Also, the draft config + needs to have truncated_vocab_size (=k) as an attribute. + 4. We allow an enhanced EAGLE architecture similar to the DeepSeek MTP + module with regards to the use of additional RMS norms. The original + EAGLE architecture 1) skips the pre-attention norm in its first + transformer block, and 2) skips the final output norm, both of which we + found to be suboptimal. We also add the support for separate norms + applying to both the token embedding and hidden states before projection + as in DeepSeek MTP, which we found to improve performance as well. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.dtype = vllm_config.model_config.dtype + self.config = config + + architectures = getattr(self.config.model, "architectures", []) + model_cls, _ = ModelRegistry.resolve_model_cls(architectures) + + self.model = model_cls(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self.fc = nn.Linear(config.model.hidden_size * 2, + config.model.hidden_size, + bias=getattr(self.config, "eagle_fc_bias", False)) + + # Modify layer normalization and residual connections as suggested + # in the EAGLE framework: https://github.com/SafeAILab/EAGLE + # While weights and biases are generally not needed, + # they are retained here to support certain unit tests + # (e.g., spec_decode/e2e/test_eagle_correctness.py). + if not hasattr(self.config.model, + "skip_prenorm") or self.config.model.skip_prenorm: + self.model.model.layers[0].input_layernorm = DummyInputLayerNorm( + weight=self.model.model.layers[0].input_layernorm.weight) + + if not hasattr( + self.config.model, + "skip_output_norm") or self.config.model.skip_output_norm: + self.model.model.norm = DummyOutputNorm() + + self.add_para_norm = False + if hasattr(self.config.model, + "add_para_norm") and self.config.model.add_para_norm: + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.add_para_norm = True + + self.orig_vocab_size = config.vocab_size + self.truncated_vocab_size = config.truncated_vocab_size + self.unpadded_vocab_size = self.truncated_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=self.truncated_vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.truncated_vocab_size, + logit_scale) + + # Token map is a idx to token mapping to reduce the vocab size for + # the draft model. Using smaller vocab size for draft, containing + # only most frequent tokens reduces the speculation overhead. This + # doesn't affect the acceptance rate much and thus gives more speed + # -up. By default, this is disabled and is only used if the EAGLE + # checkpoint file has token_map tensor. + self.token_map = None + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + + # Handle both empty previous_hidden_states + # and mismatched batch size + batch_size = inputs_embeds.size(0) + if previous_hidden_states.size(0) == 0 or \ + previous_hidden_states.size(0) != batch_size: + hidden_dim = self.config.model.hidden_size + device = inputs_embeds.device + # Create zero tensor with matching batch size + previous_hidden_states = \ + torch.zeros(batch_size, hidden_dim, device=device) + + if self.add_para_norm: + inputs_embeds = torch.cat([ + self.enorm(inputs_embeds), + self.hnorm(previous_hidden_states) + ], + dim=-1) + else: + inputs_embeds = torch.cat([inputs_embeds, previous_hidden_states], + dim=-1) + + inputs_embeds = self.fc(inputs_embeds) + + inputs_embeds[positions == 0] = 0 # masking inputs at position=0 + + hidden_states = self.model.model( + input_ids=None, + inputs_embeds=inputs_embeds, + positions=positions, + intermediate_tensors=intermediate_tensors, + ) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + if self.token_map is not None: + _logits = logits + logits = -torch.inf * torch.ones( + size=(*_logits.shape[:-1], self.orig_vocab_size), + device=_logits.device, + dtype=_logits.dtype) + + logits[..., self.token_map] = _logits + + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + # This implementation is incompatible with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B + # due to missing lm_head weights and its config being that of a + # Llama model. Here's a compatible version with the same weights: + # https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm + # Also, here's an example script for converting trained EAGLE + # checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d + model_weights = {} + for name, loaded_weight in weights: + if name == "token_map": + if self.config.truncated_vocab_size < self.config.vocab_size: + self.token_map = nn.Parameter(loaded_weight, + requires_grad=False) + elif name.startswith("fc.weight"): + weight_loader = getattr(self.fc.weight, "weight_loader", + default_weight_loader) + weight_loader(self.fc.weight, loaded_weight) + elif name.startswith("fc.bias"): + if self.fc.bias is not None: + weight_loader = getattr(self.fc.bias, "weight_loader", + default_weight_loader) + weight_loader(self.fc.bias, loaded_weight) + else: + logger.warning_once("Found bias in the loaded weights but " + "the model config doesn't have bias.") + elif name.startswith("enorm.weight"): + weight_loader = getattr(self.enorm.weight, "weight_loader", + default_weight_loader) + weight_loader(self.enorm.weight, loaded_weight) + elif name.startswith("hnorm.weight"): + weight_loader = getattr(self.hnorm.weight, "weight_loader", + default_weight_loader) + weight_loader(self.hnorm.weight, loaded_weight) + elif name.startswith("model.lm_head.") or name.startswith( + "model.model."): + model_weights[name.split("model.", 1)[-1]] = loaded_weight + elif name.startswith("lm_head.") or name.startswith("model."): + model_weights[name] = loaded_weight + else: + model_weights[f"model.{name}"] = loaded_weight + + if "lm_head.weight" in model_weights: + lm_head_weight = model_weights.pop("lm_head.weight") + + if self.token_map is not None and\ + lm_head_weight.shape[0] > self.token_map.shape[0]: + + lm_head_weight = lm_head_weight[self.token_map] + + else: + # NOTE(Shangming): initialize the placeholder for lm_head weight. + lm_head_weight = torch.zeros( + self.lm_head.org_vocab_size, + self.lm_head.embedding_dim, + dtype=self.dtype, + ) + + weight_loader = getattr(self.lm_head.weight, "weight_loader", + default_weight_loader) + weight_loader(self.lm_head.weight, lm_head_weight) + + self.model.load_weights(model_weights.items()) diff --git a/vllm/model_executor/models/ernie45.py b/vllm/model_executor/models/ernie45.py new file mode 100644 index 0000000..fcc7a1f --- /dev/null +++ b/vllm/model_executor/models/ernie45.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The Baidu team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Erine model compatible with HuggingFace weights.""" +from vllm.config import VllmConfig +from vllm.model_executor.models.llama import LlamaForCausalLM + +from .utils import PPMissingLayer + + +class Ernie4_5_ForCausalLM(LlamaForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + # Hack Llama model to fit HF format Ernie4.5 dense implementation + # Attention difference between Ernie and Llama: + # 1. rotary_dim and no Neox style. + # 2. There is no bias for o_proj in attention + for layer in self.model.layers: + if not isinstance(layer, PPMissingLayer): + layer.self_attn.rotary_emb.is_neox_style = False + layer.self_attn.o_proj.bias = None + layer.self_attn.o_proj.skip_bias_add = True \ No newline at end of file diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py new file mode 100644 index 0000000..d36da97 --- /dev/null +++ b/vllm/model_executor/models/ernie45_moe.py @@ -0,0 +1,581 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The Baidu team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only ErineMoE model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class Ernie4_5_MoeMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + use_bias: bool = False, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=use_bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=use_bias, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Ernie4_5_MoeMoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + layer_idx = extract_layer_index(prefix) + self.layer_idx = layer_idx + self.tp_size = get_tensor_model_parallel_world_size() + self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts", + None) + + if self.tp_size > config.moe_num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.moe_num_experts}.") + self.gate = ReplicatedLinear(config.hidden_size, + config.moe_num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + + self.experts = FusedMoE(num_experts=config.moe_num_experts, + top_k=config.moe_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=True, + quant_config=quant_config, + prefix=f"{prefix}.experts") + + if self.moe_num_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.moe_num_shared_experts) + self.shared_experts = Ernie4_5_MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.shared_experts", + reduce_results=self.experts.must_reduce_shared_expert_outputs( + )) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + if self.moe_num_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + + router_logits, _ = self.gate(hidden_states) + + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + + if self.moe_num_shared_experts is not None and \ + shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + + if self.tp_size > 1: + final_hidden_states = ( + self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states)) + + return final_hidden_states.view(orig_shape) + + +class Ernie4_5_MoeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: Optional[int] = None, + rope_theta: float = 500000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 131072, + rms_norm_eps: float = 1e-05, + qkv_bias: bool = False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) if len(prefix) > 0 else 0 + self.layer_idx = layer_idx + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear(hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + is_neox_style=False, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + + qkv, _ = self.qkv_proj(hidden_states) + + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + + # Attention + attn_output = self.attn(q, k, v) + # Output projection + output, _ = self.o_proj(attn_output) + return output + + +class Ernie4_5_MoeDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 500000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 131072) + self.self_attn = Ernie4_5_MoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=getattr(config, 'head_dim', None), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, 'use_bias', False), + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + layer_idx = extract_layer_index(prefix) + self.layer_idx = layer_idx + + # MoE + moe_num_experts = getattr(config, "moe_num_experts", 0) + moe_layer_start_index = getattr(config, "moe_layer_start_index", 0) + moe_layer_end_index = getattr(config, "moe_layer_end_index", + config.num_hidden_layers - 1) + moe_layer_interval = getattr(config, "moe_layer_interval", 1) + use_moe = getattr(config, "use_moe", moe_num_experts > 0) + + if (use_moe and ((layer_idx + 1) % moe_layer_interval == 0) + and layer_idx >= moe_layer_start_index + and layer_idx <= moe_layer_end_index): + self.mlp = Ernie4_5_MoeMoE(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Ernie4_5_MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + use_bias=getattr(config, 'use_bias', False), + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +@support_torch_compile +class Ernie4_5_MoeModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Ernie4_5_MoeDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Ernie4_5_MoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() + + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.moe_num_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if self.config.tie_word_embeddings and name.endswith( + "lm_head.weight"): + continue + # MTP will be supported soon. + if "mtp" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py new file mode 100644 index 0000000..aaf105e --- /dev/null +++ b/vllm/model_executor/models/exaone.py @@ -0,0 +1,551 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/blob/main/modeling_exaone.py +# Copyright 2024 The LG U+ CTO AI Tech Lab. +# Copyright 2021 The LG AI Research EXAONE Lab +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Exaone model compatible with HuggingFace weights.""" + +from collections.abc import Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.exaone import ExaoneConfig + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class ExaoneGatedMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.c_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.c_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.c_proj(x) + return x + + +class ExaoneAttention(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", None) + if self.head_dim is None: + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.out_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + is_neox_style = True + if quant_config is not None and quant_config.get_name() == "gguf": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.out_proj(attn_output) + return output + + +class ExaoneBlockAttention(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.attention = ExaoneAttention( + config=config, + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=bias, + cache_config=cache_config, + prefix=f"{prefix}.attention", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + return self.attention( + positions=positions, + hidden_states=hidden_states, + ) + + +class ExaoneDecoderLayer(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.attn = ExaoneBlockAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + self.mlp = ExaoneGatedMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.activation_function, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + else: + hidden_states, residual = self.ln_1(hidden_states, residual) + hidden_states = self.attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.ln_2(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class ExaoneModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.quant_config = quant_config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.wte = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.wte = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.wte = PPMissingLayer() + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: ExaoneDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.h", + ) + if get_pp_group().is_last_rank: + self.ln_f = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + else: + self.ln_f = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.ln_f(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".c_fc_0", 0), + (".gate_up_proj", ".c_fc_1", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "c_fc_0", + "c_fc_1", + ], + } + + # LoRA specific attributes + embedding_modules = { + "wte": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + + self.transformer = ExaoneModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.transformer.wte.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/fairseq2_llama.py b/vllm/model_executor/models/fairseq2_llama.py new file mode 100644 index 0000000..d78ee10 --- /dev/null +++ b/vllm/model_executor/models/fairseq2_llama.py @@ -0,0 +1,154 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. +# Copyright 2024 Meta Platforms, Inc. and affiliates. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Llama model for fairseq2 weights.""" + +from collections.abc import Iterable + +import torch +from torch.nn import Parameter + +from vllm.config import VllmConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.linear import set_weight_attrs +from vllm.model_executor.models.llama import LlamaForCausalLM + +from .utils import AutoWeightsLoader, WeightsMapper + + +class Fairseq2LlamaForCausalLM(LlamaForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + # For the model loader to read only the relevant checkpoint files + self.allow_patterns_overrides = [ + # either the full checkpoint + "model.pt", + # or the tp-sharded checkpoint of the current rank + f"model.{self.tp_rank}.pt", + ] + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + # fairseq2's serialization adds a wrapper to usual .pt state_dict's: + # { "model_key": my_model_name, "my_model_name": state_dict } + # which we first need to unpack + weights_wrapped = dict(weights) + weights = weights_wrapped[ + weights_wrapped["model_key"]].items() # type: ignore + + # remap keys + fs2_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "decoder_frontend.embed.": "model.embed_tokens.", + "decoder.": "model.", + "final_proj.": "lm_head.", + }, + orig_to_new_substr={ + ".self_attn_layer_norm.": ".input_layernorm.", + ".ffn_layer_norm.": ".post_attention_layernorm.", + ".self_attn.output_proj.": ".self_attn.o_proj.", + ".ffn.gate_proj.": ".mlp.gate_proj.", + ".ffn.inner_proj.": ".mlp.up_proj.", + ".ffn.output_proj.": ".mlp.down_proj.", + ".layer_norm.": ".norm.", + }, + ) + weights = fs2_to_vllm_mapper.apply(weights) + + params = dict(self.named_parameters()) + + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights( + (self.reshape_fairseq2_weights(name, loaded_weight, params) + for name, loaded_weight in weights)) + + def flag_sharded_weights(self, params: dict[str, Parameter]): + """Sets the `is_sharded_weight` flag to True for all sharded weights""" + for name, param in params.items(): + modules = name.split(".") + if "norm" in name and len(param.size()) < 2: + # layer norms are not sharded + continue + elif any(emb in modules for emb in ["embed_tokens", "lm_head"]): + # for now we repeat embedding layers for compatibility + continue + else: + # all other layers are sharded + set_weight_attrs(param, {"is_sharded_weight": True}) + + def reshape_fairseq2_weights( + self, + name: str, + loaded_weight: torch.Tensor, + params: dict[str, Parameter], + ) -> tuple[str, torch.Tensor]: + """Reshape fairseq2's weights.""" + + def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor: + attn_in = self.config.head_dim * n_heads + # check for a sharded weight on dim 0 + if attn_in // self.tp_size == w.size()[0]: + attn_in //= self.tp_size + n_heads //= self.tp_size + attn_out = self.config.hidden_size + return (w.view(n_heads, attn_in // n_heads // 2, 2, + attn_out).transpose(1, + 2).reshape(attn_in, attn_out)) + + modules = name.split(".") + + # rotary embeds should be sliced + if "k_proj" in modules: + loaded_weight = permute(loaded_weight, + self.config.num_key_value_heads) + + elif "q_proj" in modules: + loaded_weight = permute(loaded_weight, + self.config.num_attention_heads) + + # We make the loaded weights compatible with both + # full checkpoints and tp sharded checkpoints. + # Embeddings are repeated to fit the vocab size. + # Other weights are flagged for the weight_loader calls. + if any(emb in modules for emb in ["embed_tokens", "lm_head"]): + # Embeddings are sharded on dim 0 + dim = 0 + # In fairseq2, vocab size has to be divisible by tp_size + # so we don't worry about padding + if self.tp_size > 1 and loaded_weight.shape[ + dim] < self.config.vocab_size: + assert loaded_weight.shape[ + dim] * self.tp_size == self.config.vocab_size, \ + "vocab_size should be divisible by tp_size." + repeats = [1] * len(loaded_weight.size()) + repeats[dim] = self.tp_size + # repeat to match vocab size and to be easily 'narrow'able + loaded_weight = loaded_weight.repeat(repeats) + set_weight_attrs(params[name], {"is_sharded_weight": False}) + # if embeddings are sharded, the rest is too + if "embed_tokens" in modules: + self.flag_sharded_weights(params) + + return name, loaded_weight diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py new file mode 100644 index 0000000..968bc7d --- /dev/null +++ b/vllm/model_executor/models/falcon.py @@ -0,0 +1,564 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py +# Copyright 2023 The vLLM team. +# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights +# reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Falcon model.""" + +import math +from collections.abc import Iterable +from typing import Optional, Union + +import os +import re +import torch +from torch import nn +from torch.nn import LayerNorm +from transformers import FalconConfig as HF_FalconConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import RWConfig + +from .interfaces import SupportsPP +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +from vllm import _custom_ops as ops +from vllm.model_executor.utils import pad_weight, gemm_bank_conf + +FalconConfig = Union[HF_FalconConfig, RWConfig] + + +def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(1, + 1 + 2 * num_remaining_heads, + 2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + + return slopes + + +class FalconAttention(nn.Module): + + def __init__( + self, + config: FalconConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.head_dim = self.hidden_size // self.total_num_heads + assert self.head_dim * self.total_num_heads == self.hidden_size + + self.new_decoder_architecture = config.new_decoder_architecture + self.multi_query = config.multi_query + + if self.new_decoder_architecture: + self.total_num_kv_heads = config.num_kv_heads + elif self.multi_query: + self.total_num_kv_heads = 1 + else: + self.total_num_kv_heads = self.total_num_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.bias, + skip_bias_add=True, + quant_config=quant_config, + ) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + self.reduce_row_parallel_results = not (config.new_decoder_architecture + or config.parallel_attn) + self.dense = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=config.bias, + skip_bias_add=True, + quant_config=quant_config, + reduce_results=self.reduce_row_parallel_results) + + self.use_rotary = config.rotary + self.use_alibi = config.alibi + assert not (self.use_rotary and self.use_alibi), ( + "Rotary and alibi are mutually exclusive.") + + if self.use_rotary: + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, + "max_position_embeddings", 8192) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + quant_config=quant_config, + prefix=f"{prefix}.attn") + elif self.use_alibi: + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = (_get_alibi_slopes(self.total_num_heads) * + self.inv_norm_factor) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + self.attn = Attention(self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + alibi_slopes=alibi_slopes, + quant_config=quant_config, + prefix=f"{prefix}.attn") + else: + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + self.quant_method = None + if quant_config is not None: + self.quant_method=quant_config.get_name() + self.quant_config=quant_config + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, bias = self.query_key_value(hidden_states) + # if os.environ.get('FA_PAD') == '1' and self.quant_method is None: + # qkv = qkv[...,:-32] + if bias is not None: + qkv += bias + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.use_rotary: + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + attn_output, bias = self.dense(attn_output) + return attn_output, bias + + +class FalconMLP(nn.Module): + + def __init__( + self, + config: FalconConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.hidden_size + + self.dense_h_to_4h = ColumnParallelLinear(hidden_size, + 4 * hidden_size, + bias=config.bias, + skip_bias_add=True, + quant_config=quant_config) + self.act = get_act_fn("gelu") + self.reduce_row_parallel_results = not (config.new_decoder_architecture + or config.parallel_attn) + self.dense_4h_to_h = RowParallelLinear( + 4 * hidden_size, + hidden_size, + bias=config.bias, + skip_bias_add=True, + reduce_results=self.reduce_row_parallel_results, + quant_config=quant_config) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # NOTE(zhuohan): Following huggingface, we do not fuse bias add here. + x, bias = self.dense_h_to_4h(x) + if bias is not None: + x += bias + x = self.act(x) + x, bias = self.dense_4h_to_h(x) + return x, bias + + +class FalconDecoderLayer(nn.Module): + + def __init__( + self, + config: FalconConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.self_attention = FalconAttention( + config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attention") + self.mlp = FalconMLP(config, quant_config) + self.config = config + + if (not hasattr(config, "num_ln_in_parallel_attn")): + config.num_ln_in_parallel_attn = None + + if (config.num_ln_in_parallel_attn is None + and config.new_decoder_architecture): + config.num_ln_in_parallel_attn = 2 + + if not config.parallel_attn: + self.post_attention_layernorm = LayerNorm( + hidden_size, eps=config.layer_norm_epsilon) + self.input_layernorm = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + else: + if config.num_ln_in_parallel_attn == 2: + # The layer norm before self-attention + self.ln_attn = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + # The layer norm before the MLP + self.ln_mlp = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + else: + self.input_layernorm = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + + self.reduce_row_parallel_results = not (config.new_decoder_architecture + or config.parallel_attn) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + + if self.config.num_ln_in_parallel_attn == 2: + attention_layernorm_out = self.ln_attn(hidden_states) + mlp_layernorm_out = self.ln_mlp(hidden_states) + else: + attention_layernorm_out = self.input_layernorm(hidden_states) + + # Self attention. + attention_output, attention_bias = self.self_attention( + positions=positions, + hidden_states=attention_layernorm_out, + ) + if self.reduce_row_parallel_results and attention_bias is not None: + attention_output += attention_bias + + if not self.config.new_decoder_architecture: + if self.config.parallel_attn: + mlp_layernorm_out = attention_layernorm_out + else: + residual += attention_output + mlp_layernorm_out = self.post_attention_layernorm(residual) + + if (self.config.new_decoder_architecture and self.config.parallel_attn + and self.config.num_ln_in_parallel_attn == 1): + mlp_layernorm_out = attention_layernorm_out + + # MLP. + mlp_output, mlp_bias = self.mlp(mlp_layernorm_out) + if self.reduce_row_parallel_results and mlp_bias is not None: + mlp_output += mlp_bias + + if not self.reduce_row_parallel_results: + # When MLP and Attention layers are parallel, we can use + # only one all-reduce operator to reduce the results from + # both MLP and Attention layers. + mlp_output += attention_output + mlp_output = tensor_model_parallel_all_reduce(mlp_output) + if attention_bias is not None: + mlp_output += attention_bias + if mlp_bias is not None: + mlp_output += mlp_bias + + output = mlp_output + residual + return output + + +@support_torch_compile +class FalconModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.use_alibi = config.alibi + + # Embedding + LN Embedding + self.word_embeddings = VocabParallelEmbedding( + config.vocab_size, + self.embed_dim, + ) + + # Transformer blocks + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: FalconDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.h") + + # Final Layer Norm + self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + self.quant_method = None + if quant_config is not None: + self.quant_method=quant_config.get_name() + self.quant_config=quant_config + + self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' + self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' + self.use_fa_pad = os.environ.get('FA_PAD') == '1' + self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' + self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.word_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + hidden_states = intermediate_tensors["hidden_states"] + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + total_num_heads = self.config.num_attention_heads + if self.config.new_decoder_architecture: + total_num_kv_heads = self.config.num_kv_heads + elif self.config.multi_query: + total_num_kv_heads = 1 + else: + total_num_kv_heads = total_num_heads + num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + if "query_key_value" in name: + output_dim = getattr(param, "output_dim", None) + loaded_weight_shape = loaded_weight.shape + if output_dim is not None: + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + + (total_num_kv_heads, num_query_heads_per_kv_head + 2, + -1) + loaded_weight_shape[output_dim + 1:]) + wq = loaded_weight.narrow( + output_dim + 1, 0, + num_query_heads_per_kv_head).reshape( + *loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + wk = loaded_weight.narrow( + output_dim + 1, num_query_heads_per_kv_head, + 1).reshape(*loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + wv = loaded_weight.narrow( + output_dim + 1, num_query_heads_per_kv_head + 1, + 1).reshape(*loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + loaded_weight = torch.cat([wq, wk, wv], dim=output_dim) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + if self.use_llama_nn and self.quant_method is None : + lay_key_words = [ + "self_attention.query_key_value.weight", + "self_attention.dense.weight", + "mlp.dense_h_to_4h.weight", + "mlp.dense_4h_to_h.weight", + ] + combined_words = "|".join(lay_key_words) + + # lay_qkv_words = ["self_attention.query_key_value.weight"] + # qkv_words = "|".join(lay_qkv_words) + + for layername in loaded_params: + weight = params_dict[layername] + matches = re.findall(combined_words, layername) + if matches: + # if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]): + # weight.data = pad_weight(weight.data, 32) + + # if self.use_fa_pad and (re.findall(qkv_words, layername)): + # if not gemm_bank_conf(weight.data.shape[0]): + # weight.data = pad_weight(weight.data, 32) + + _weight = torch.zeros_like(weight.data) + ori_shape =_weight.shape + + ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1]) + weight.data.copy_(_weight) + + weight.data=weight.data.reshape(ori_shape[1], -1) + return loaded_params + + +class FalconForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = { + "query_key_value": ["query_key_value"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.transformer = FalconModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) + # only Falcon-11B doesn't share lm_head weight with word embeddings + # and previous Falcon model doesn't have tie_word_embeddings config + # so we set tie_word_embeddings to True by default + self.tie_word_embeddings = (config.tie_word_embeddings + if config.tie_word_embeddings is not None + else True) + if self.tie_word_embeddings: + self.lm_head = self.transformer.word_embeddings + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py new file mode 100644 index 0000000..a76e1f2 --- /dev/null +++ b/vllm/model_executor/models/falcon_h1.py @@ -0,0 +1,708 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only FalconH1 model.""" +from collections.abc import Iterable +from typing import Optional + +import torch +from torch import nn +from transformers import FalconH1Config + +from vllm import envs +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba2_metadata import ( + Mamba2Metadata, prepare_mamba2_metadata) +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class FalconH1MLP(nn.Module): + + def __init__( + self, + config: FalconH1Config, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=config.hidden_size, + output_sizes=[config.intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + input_size=config.intermediate_size, + output_size=config.hidden_size, + bias=bias, + quant_config=quant_config, + ) + self.tp_size = get_tensor_model_parallel_world_size() + self.intermediate_size = config.intermediate_size + self.gate_multiplier, self.down_multiplier = config.mlp_multipliers + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x[:, :self.intermediate_size // self.tp_size] *= self.gate_multiplier + x = self.act_fn(x) + x, _ = self.down_proj(x) + x = x * self.down_multiplier + return x + + +class FalconH1SSMDecoderLayer(nn.Module): + + def __init__( + self, + config: FalconH1Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.tp_size = get_tensor_model_parallel_world_size() + + self.d_ssm = (int(config.mamba_expand * config.hidden_size) + if config.mamba_d_ssm is None else config.mamba_d_ssm) + + self.mamba = MambaMixer2( + hidden_size=config.hidden_size, + ssm_state_size=config.mamba_d_state, + conv_kernel_size=config.mamba_d_conv, + intermediate_size=self.d_ssm, + use_conv_bias=config.mamba_conv_bias, + use_bias=config.mamba_proj_bias, + n_groups=config.mamba_n_groups, + num_heads=config.mamba_n_heads, + head_dim=config.mamba_d_head, + rms_norm_eps=config.rms_norm_eps, + activation=config.hidden_act, + quant_config=quant_config, + use_rms_norm=config.mamba_rms_norm, + prefix=f"{prefix}.mixer", + chunk_size=config.mamba_chunk_size, + ) + # n_groups is overridden later by `MambaMixer2` + self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state + self.zxbcdt_multipliers = config.ssm_multipliers + self._init_mup_vector() + + def _init_mup_vector(self): + """ + Non learnable per-block scaling vector composed of element-wise + multipliersapplied to each separate contiguous block of the output + of the linear projection (in_proj) before further processing + (gating, convolution, SSM): + + - Z block: [0 : d_ssm] → zxbcdt_multipliers[0] + - X block: [d_ssm : 2 * d_ssm] → zxbcdt_multipliers[1] + - B block: [2 * d_ssm : 2 * d_ssm + G * S] → zxbcdt_multipliers[2] + - C block: [2 * d_ssm + G * S : 2 * d_ssm + 2 * G * S] + → zxbcdt_multipliers[3] + - dt block: [2 * d_ssm + 2 * G * S : end] → zxbcdt_multipliers[4] + + where: + - d_ssm: Dimension of state-space model latent + - G: Number of groups (n_groups) + - S: SSM state size per group + - All indices are divided by tp_size to support tensor parallelism + """ + vector_shape = (2 * self.d_ssm + 2 * self.groups_time_state_size + + self.config.mamba_n_heads) // self.tp_size + mup_vector = torch.ones(1, vector_shape) + # Z vector 0 -> d_ssm + mup_vector[:, :self.d_ssm // + self.tp_size] *= self.zxbcdt_multipliers[0] + # X vector d_ssm -> 2 * d_ssm + mup_vector[:, + (self.d_ssm // + self.tp_size):(2 * self.d_ssm // + self.tp_size)] *= self.zxbcdt_multipliers[1] + # B vector 2 * d_ssm -> 2 * d_ssm + (n_group * d_state) + mup_vector[ + :, + (2 * self.d_ssm) // + self.tp_size:(2 * self.d_ssm + self.groups_time_state_size) // + self.tp_size, + ] *= self.zxbcdt_multipliers[2] + # C vector 2 * d_ssm + (n_group * d_state) + # -> 2 * d_ssm + 2 * (n_group * d_state) + mup_vector[ + :, + (2 * self.d_ssm + self.groups_time_state_size) // + self.tp_size:(2 * self.d_ssm + 2 * self.groups_time_state_size) // + self.tp_size, + ] *= self.zxbcdt_multipliers[3] + # dt vector 2 * d_ssm + 2 * (n_group * d_state) + # -> 2 * d_ssm + 2 * (n_group * d_state) + n_heads + mup_vector[ + :, + (2 * self.d_ssm + 2 * self.groups_time_state_size) // + self.tp_size:, + ] *= self.zxbcdt_multipliers[4] + + self.register_buffer("mup_vector", mup_vector, persistent=False) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + **kwargs, + ): + hidden_states = self.mamba( + hidden_states, + mamba_cache_params, + mamba2_metadata=mamba2_metadata, + mup_vector=self.mup_vector, + ) + return hidden_states, residual + + +class FalconH1AttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: FalconH1Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + rope_theta = getattr(config, "rope_theta", 1e11) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = (config.hidden_size // self.total_num_heads if getattr( + config, "head_dim", None) is None else config.head_dim) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if hasattr(config, "partial_rotary_factor"): + rotary_dim = self.head_dim * config.partial_rotary_factor + elif hasattr(config, "attn_rotary_emb"): + rotary_dim = config.attn_rotary_emb # for backward compatibility + else: + rotary_dim = self.head_dim # default + + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=rotary_dim, + max_position=max_position_embeddings, + rope_scaling=rope_scaling, + base=rope_theta, + is_neox_style=True, + dtype=None, # see impl of get_rope + ) + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + self.key_multiplier = config.key_multiplier + + def self_attention( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + k = k * self.key_multiplier + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ): + hidden_states = self.self_attention( + positions=positions, + hidden_states=hidden_states, + ) + return hidden_states, residual + + +class FalconH1ParallelHybrid(nn.Module): + """ + A hybrid decoder layer for FalconH1 where the input is processed + in parallel through both the self-attention branch and the SSM (Mamba) + branch. Their outputs are then summed to produce the final hidden state. + + This layer uses: + - FalconH1AttentionDecoderLayer for the multi-head self-attention branch. + - FalconH1SSMDecoderLayer for the state-space (Mamba) branch. + """ + + def __init__( + self, + config: FalconH1Config, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + # Instantiate the attention branch + self.self_attn = FalconH1AttentionDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + # In V1 all attention/ssm layers must have + # different index in prefix + ssm_layer_idx = config.num_hidden_layers + layer_idx + ssm_prefix = prefix.split(".")[0] + f".{ssm_layer_idx}" + + # Instantiate the SSM branch + self.mamba = FalconH1SSMDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=ssm_prefix, + ) + self.ssm_out_multiplier = config.ssm_out_multiplier + self.ssm_in_multiplier = config.ssm_in_multiplier + + self.attention_in_multiplier = config.attention_in_multiplier + self.attn_out_multiplier = config.attention_out_multiplier + + self.feed_forward = FalconH1MLP(config) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + **kwargs, + ): + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Process input through the attention branch. + # FalconH1AttentionDecoderLayer expects positions, hidden_states, + # kv_cache, attn_metadata, and residual. + attn_hidden, _ = self.self_attn( + positions=positions, + hidden_states=hidden_states * self.attention_in_multiplier, + residual=residual, + **kwargs, + ) + + # Process input through the SSM branch. + # FalconH1SSMDecoderLayer expects hidden_states, attn_metadata, + # residual, mamba_cache_params, and sequence_idx. + ssm_hidden, _ = self.mamba( + hidden_states=hidden_states * self.ssm_in_multiplier, + residual=residual, + mamba_cache_params=mamba_cache_params, + mamba2_metadata=mamba2_metadata, + **kwargs, + ) + # Sum the outputs from both branches. + # We assume both branches produce outputs of the same + # dimensionality (config.hidden_size). + hidden_states = (attn_hidden * self.attn_out_multiplier) + ( + ssm_hidden * self.ssm_out_multiplier) + hidden_states = hidden_states + residual + + # feed-forward + residual = hidden_states + hidden_states = self.pre_ff_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class FalconH1Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: FalconH1Config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank: + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.embedding_multiplier = config.embedding_multiplier + else: + self.embed_tokens = PPMissingLayer() + self.embedding_multiplier = 1.0 + + def get_layer(prefix: str): + layer_idx = int(prefix.rsplit(".", 1)[1]) + layer_class = FalconH1ParallelHybrid + return layer_class( + config, + layer_idx, + cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + if get_pp_group().is_last_rank: + self.final_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + else: + self.final_layernorm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # pass a sequence index tensor, that is required for + # proper continuous batching computation including + # chunked prefill + attn_metadata = get_forward_context().attn_metadata + + if not envs.VLLM_USE_V1: + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.mamba_chunk_size, + attn_metadata=attn_metadata, + ) + else: + # v1 get mamba2_metadata from forward_context + mamba2_metadata = None + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds * self.embedding_multiplier + else: + hidden_states = (self.get_input_embeddings(input_ids) * + self.embedding_multiplier) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + layer_mamba_cache_params = None + if mamba_cache_params: + layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i) + hidden_states = layer( + positions=positions, + hidden_states=hidden_states, + mamba_cache_params=layer_mamba_cache_params, + mamba2_metadata=mamba2_metadata, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + }) + hidden_states = self.final_layernorm(hidden_states) + return hidden_states + + +class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, + IsHybrid): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + assert (not cache_config.enable_prefix_caching + ), "FalconH1 currently does not support prefix caching" + + self.quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.model = FalconH1Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.tie_word_embeddings = config.tie_word_embeddings + self.unpadded_vocab_size = config.vocab_size + self.mamba_cache: Optional[MambaCacheManager] = None + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else + lora_config.lora_vocab_padding_size), + ) + self.lm_head_multiplier = config.lm_head_multiplier + if self.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.embed_tokens) + # Used to track and store by the Mamba cache between steps. + + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, + config.vocab_size, + scale=config.lm_head_multiplier, + ) + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + + mamba_cache_params = None + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + self.mamba_cache = MambaCacheManager( + self.vllm_config, + self.lm_head.weight.dtype if hasattr( + self.lm_head, 'weight') else torch.bfloat16, + self.config.num_hidden_layers, + *self._get_mamba_cache_shape(), + ) + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + + hidden_states = self.model( + input_ids, + positions, + mamba_cache_params, + intermediate_tensors, + inputs_embeds, + ) + + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> tuple[tuple[int, int], tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + + conv_state_shape, temporal_state_shape = None, None + + intermediate_size = (int(self.config.mamba_expand * + hidden_size) if self.config.mamba_d_ssm + is None else self.config.mamba_d_ssm) + + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = self.config.mamba_n_groups + extra_groups_for_head_shards( + self.config.mamba_n_groups, world_size) + + # - heads and n_groups are TP-ed + conv_dim = intermediate_size + 2 * n_groups * self.config.mamba_d_state + conv_state_shape = ( + divide(conv_dim, world_size), + self.config.mamba_d_conv - 1, + ) + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, d_head, d_state) = (128, 64, 128) + temporal_state_shape = ( + divide(self.config.mamba_n_heads, world_size), + self.config.mamba_d_head, + self.config.mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if "A_log" in name: + name = name.replace("A_log", "A") + + if "mamba" in name: + name = name.replace("mamba", "mamba.mamba") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + if self.tie_word_embeddings and "lm_head" in name: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + if self.tie_word_embeddings: + loaded_params.add("lm_head.weight") + return loaded_params diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py new file mode 100644 index 0000000..1bedac2 --- /dev/null +++ b/vllm/model_executor/models/florence2.py @@ -0,0 +1,1113 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from collections import OrderedDict +from collections.abc import Iterable, Mapping, Sequence +from typing import Literal, Optional, TypedDict, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers import BartTokenizer, BatchFeature, PretrainedConfig + +from vllm.config import VllmConfig +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.bart import (BartDecoder, BartEncoder, + BartParallelLMHead, + BartScaledWordEmbedding) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseProcessingInfo, + EncDecMultiModalProcessor, + PromptIndexTargets, PromptInsertion, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors + +from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, + SupportsV0Only) +from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings + + +class Florence2ImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: (batch_size, num_channel, height, width)""" + + +# ViT implementation are all copied from +# https://huggingface.co/microsoft/Florence-2-base/blob/main/modeling_florence2.py +class LearnedAbsolutePositionEmbedding2D(nn.Module): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, embedding_dim=256, num_pos=50): + super().__init__() + self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2) + self.column_embeddings = nn.Embedding( + num_pos, embedding_dim - (embedding_dim // 2)) + + def forward(self, pixel_values): + """ + pixel_values: (batch_size, height, width, num_channels) + returns: (batch_size, height, width, embedding_dim * 2) + """ + if len(pixel_values.shape) != 4: + raise ValueError('pixel_values must be a 4D tensor') + height, width = pixel_values.shape[1:3] + width_values = torch.arange(width, device=pixel_values.device) + height_values = torch.arange(height, device=pixel_values.device) + x_emb = self.column_embeddings(width_values) + y_emb = self.row_embeddings(height_values) + # (height, width, embedding_dim * 2) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(height, 1, 1), + y_emb.unsqueeze(1).repeat(1, width, 1) + ], + dim=-1) + # (embedding_dim * 2, height, width) + pos = pos.permute(2, 0, 1) + pos = pos.unsqueeze(0) + # (batch_size, embedding_dim * 2, height, width) + pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) + # (batch_size, height, width, embedding_dim * 2) + pos = pos.permute(0, 2, 3, 1) + return pos + + +class PositionalEmbeddingCosine1D(nn.Module): + """ + This class implements a very simple positional encoding. It follows closely + the encoder from the link below: + https://pytorch.org/tutorials/beginner/translation_transformer.html + Args: + embed_dim: The dimension of the embeddings. + dropout_prob: The dropout probability. + max_seq_len: The maximum length to precompute the positional encodings. + """ + + def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None: + super().__init__() + self.embed_dim = embed_dim + self.max_seq_len = max_seq_len + # Generate the sinusoidal arrays. + factor = math.log(10000) + denominator = torch.exp(-factor * torch.arange(0, self.embed_dim, 2) / + self.embed_dim) + # Matrix where rows correspond to a positional embedding as a function + # of the position index (i.e., the row index). + frequencies = \ + torch.arange(0, self.max_seq_len) \ + .reshape(self.max_seq_len, 1) * denominator + pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim)) + # Populate uneven entries. + pos_idx_to_embed[:, 0::2] = torch.sin(frequencies) + pos_idx_to_embed[:, 1::2] = torch.cos(frequencies) + # Save the positional embeddings in a constant buffer. + # self.register_buffer("pos_idx_to_embed", pos_idx_to_embed) + self.pos_idx_to_embed = nn.Parameter(pos_idx_to_embed, + requires_grad=False) + + def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor: + """ + Args: + seq_embeds: The sequence embeddings in order. Allowed size: + 1. [T, D], where T is the length of the sequence, and D is the + frame embedding dimension. + 2. [B, T, D], where B is the batch size and T and D are the + same as above. + Returns a tensor of with the same dimensions as the input: i.e., + [1, T, D] or [T, D]. + """ + shape_len = len(seq_embeds.shape) + assert 2 <= shape_len <= 3 + len_seq = seq_embeds.size(-2) + assert len_seq <= self.max_seq_len + pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :] + # Adapt pre-computed positional embeddings to the input. + if shape_len == 3: + pos_embeds = pos_embeds.view( + (1, pos_embeds.size(0), pos_embeds.size(1))) + return pos_embeds + + +class MySequential(nn.Sequential): + + def forward(self, *inputs): + for module in self._modules.values(): + if isinstance(inputs, tuple): + inputs = module(*inputs) + else: + inputs = module(inputs) + return inputs + + +class PreNorm(nn.Module): + + def __init__(self, norm, fn): + super().__init__() + self.norm = norm + self.fn = fn + + def forward(self, x, *args, **kwargs): + shortcut = x + if self.norm is not None: + x, size = self.fn(self.norm(x), *args, **kwargs) + else: + x, size = self.fn(x, *args, **kwargs) + + x = shortcut + x + + return x, size + + +class Mlp(nn.Module): + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.net = nn.Sequential( + OrderedDict([("fc1", nn.Linear(in_features, hidden_features)), + ("act", act_layer()), + ("fc2", nn.Linear(hidden_features, out_features))])) + + def forward(self, x, size): + return self.net(x), size + + +class DepthWiseConv2d(nn.Module): + + def __init__( + self, + dim_in, + kernel_size, + padding, + stride, + bias=True, + ): + super().__init__() + self.dw = nn.Conv2d(dim_in, + dim_in, + kernel_size=kernel_size, + padding=padding, + groups=dim_in, + stride=stride, + bias=bias) + + def forward(self, x, size): + B, N, C = x.shape + H, W = size + assert N == H * W + + x = self.dw(x.transpose(1, 2).view(B, C, H, W)) + size = (x.size(-2), x.size(-1)) + x = x.flatten(2).transpose(1, 2) + return x, size + + +class ConvEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, + patch_size=7, + in_chans=3, + embed_dim=64, + stride=4, + padding=2, + norm_layer=None, + pre_norm=True): + super().__init__() + self.patch_size = patch_size + + self.proj = nn.Conv2d(in_chans, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=padding) + + dim_norm = in_chans if pre_norm else embed_dim + self.norm = norm_layer(dim_norm) if norm_layer else None + + self.pre_norm = pre_norm + + def forward(self, x, size): + H, W = size + if len(x.size()) == 3: + if self.norm and self.pre_norm: + x = self.norm(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W) + + x = self.proj(x) + + _, _, H, W = x.shape + x = rearrange(x, 'b c h w -> b (h w) c') + if self.norm and not self.pre_norm: + x = self.norm(x) + + return x, (H, W) + + +class ChannelAttention(nn.Module): + + def __init__(self, dim, groups=8, qkv_bias=True): + super().__init__() + + self.groups = groups + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + def forward(self, x, size): + B, N, C = x.shape + + qkv = self.qkv(x).reshape(B, N, 3, self.groups, + C // self.groups).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * (float(N)**-0.5) + attention = q.transpose(-1, -2) @ k + attention = attention.softmax(dim=-1) + x = (attention @ v.transpose(-1, -2)).transpose(-1, -2) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + return x, size + + +class ChannelBlock(nn.Module): + + def __init__(self, + dim, + groups, + mlp_ratio=4., + qkv_bias=True, + drop_path_rate=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + conv_at_attn=True, + conv_at_ffn=True): + super().__init__() + + self.conv1 = PreNorm(None, DepthWiseConv2d( + dim, 3, 1, 1)) if conv_at_attn else None + self.channel_attn = PreNorm( + norm_layer(dim), + ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias), + ) + self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, + 1)) if conv_at_ffn else None + self.ffn = PreNorm( + norm_layer(dim), + Mlp(in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer), + ) + + def forward(self, x, size): + if self.conv1: + x, size = self.conv1(x, size) + x, size = self.channel_attn(x, size) + + if self.conv2: + x, size = self.conv2(x, size) + x, size = self.ffn(x, size) + + return x, size + + +def window_partition(x, window_size: int): + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, + C) + windows = x.permute(0, 1, 3, 2, 4, + 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int): + B = batch_size + + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + + def __init__(self, dim, num_heads, window_size, qkv_bias=True): + + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = float(head_dim)**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, size): + + H, W = size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + x = window_partition(x, self.window_size) + x = x.view(-1, self.window_size * self.window_size, C) + + # W-MSA/SW-MSA + # attn_windows = self.attn(x_windows) + + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + attn = self.softmax(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + + # merge windows + x = x.view(-1, self.window_size, self.window_size, C) + x = window_reverse(x, B, self.window_size, Hp, Wp) + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + return x, size + + +class SpatialBlock(nn.Module): + + def __init__(self, + dim, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + drop_path_rate=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + conv_at_attn=True, + conv_at_ffn=True): + super().__init__() + + self.conv1 = PreNorm(None, DepthWiseConv2d( + dim, 3, 1, 1)) if conv_at_attn else None + self.window_attn = PreNorm( + norm_layer(dim), + WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias), + ) + self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, + 1)) if conv_at_ffn else None + self.ffn = PreNorm( + norm_layer(dim), + Mlp(in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer), + ) + + def forward(self, x, size): + if self.conv1: + x, size = self.conv1(x, size) + x, size = self.window_attn(x, size) + + if self.conv2: + x, size = self.conv2(x, size) + x, size = self.ffn(x, size) + return x, size + + +class DaViT(nn.Module): + + def __init__( + self, + in_chans=3, + num_classes=1000, + depths=(1, 1, 3, 1), + patch_size=(7, 2, 2, 2), + patch_stride=(4, 2, 2, 2), + patch_padding=(3, 0, 0, 0), + patch_prenorm=(False, False, False, False), + embed_dims=(64, 128, 192, 256), + num_heads=(3, 6, 12, 24), + num_groups=(3, 6, 12, 24), + window_size=7, + mlp_ratio=4., + qkv_bias=True, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + enable_checkpoint=False, + conv_at_attn=True, + conv_at_ffn=True, + ): + super().__init__() + + self.num_classes = num_classes + self.embed_dims = embed_dims + self.num_heads = num_heads + self.num_groups = num_groups + self.num_stages = len(self.embed_dims) + self.enable_checkpoint = enable_checkpoint + assert self.num_stages == len(self.num_heads) == len(self.num_groups) + + num_stages = len(embed_dims) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, + sum(depths) * 2) + ] + + depth_offset = 0 + convs = [] + blocks = [] + for i in range(num_stages): + conv_embed = ConvEmbed( + patch_size=patch_size[i], + stride=patch_stride[i], + padding=patch_padding[i], + in_chans=in_chans if i == 0 else self.embed_dims[i - 1], + embed_dim=self.embed_dims[i], + norm_layer=norm_layer, + pre_norm=patch_prenorm[i]) + convs.append(conv_embed) + + block = MySequential(*[ + MySequential( + OrderedDict([('spatial_block', + SpatialBlock( + embed_dims[i], + num_heads[i], + window_size, + drop_path_rate=dpr[depth_offset + j * 2], + qkv_bias=qkv_bias, + mlp_ratio=mlp_ratio, + conv_at_attn=conv_at_attn, + conv_at_ffn=conv_at_ffn, + )), + ('channel_block', + ChannelBlock( + embed_dims[i], + num_groups[i], + drop_path_rate=dpr[depth_offset + j * 2 + + 1], + qkv_bias=qkv_bias, + mlp_ratio=mlp_ratio, + conv_at_attn=conv_at_attn, + conv_at_ffn=conv_at_ffn, + ))])) for j in range(depths[i]) + ]) + blocks.append(block) + depth_offset += depths[i] * 2 + + self.convs = nn.ModuleList(convs) + self.blocks = nn.ModuleList(blocks) + + self.avgpool = nn.AdaptiveAvgPool1d(1) + + @property + def dim_out(self): + return self.embed_dims[-1] + + def forward_features_unpool(self, x): + """ + forward until avg pooling + Args: + x (_type_): input image tensor + """ + input_size = (x.size(2), x.size(3)) + for conv, block in zip(self.convs, self.blocks): + x, input_size = conv(x, input_size) + x, input_size = block(x, input_size) + return x + + def forward_features(self, x): + x = self.forward_features_unpool(x) + + # (batch_size, num_tokens, token_dim) + x = self.avgpool(x.transpose(1, 2)) + # (batch_size, 1, num_tokens) + x = torch.flatten(x, 1) + x = self.norms(x) + + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + @classmethod + def from_config(cls, config): + return cls( + depths=config.depths, + embed_dims=config.dim_embed, + num_heads=config.num_heads, + num_groups=config.num_groups, + patch_size=config.patch_size, + patch_stride=config.patch_stride, + patch_padding=config.patch_padding, + patch_prenorm=config.patch_prenorm, + drop_path_rate=config.drop_path_rate, + window_size=config.window_size, + ) + + +# Language backbone and processor implementation +class Florence2LanguageModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + + self.vocab_size = config.vocab_size + + self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model) + self.encoder = BartEncoder(config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.encoder") + self.decoder = BartDecoder(config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.decoder") + + if self.config.tie_word_embeddings: + self.encoder.embed_tokens.weight = self.shared.weight + self.decoder.embed_tokens.weight = self.shared.weight + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Args: + input_ids + Indices of *decoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. + positions + Positions of *decoder* input sequence tokens. + encoder_input_ids + Indices of *encoder* input sequence tokens in the vocabulary. + encoder_positions: + Positions of *encoder* input sequence tokens. + Returns: + Model output torch.Tensor + """ + + encoder_hidden_states = None + + if inputs_embeds is not None or encoder_input_ids.numel() > 0: + # Run encoder attention if a non-zero number of encoder tokens + # are provided as input + encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, + positions=encoder_positions, + inputs_embeds=inputs_embeds) + + # decoder outputs consists of + # (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + decoder_input_ids=input_ids, + decoder_positions=positions, + encoder_hidden_states=encoder_hidden_states) + + return decoder_outputs + + +class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + + self.config = config + self.model = Florence2LanguageModel(vllm_config=vllm_config, + prefix=f"{prefix}.model") + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.vocab_size = config.vocab_size + self.lm_head = BartParallelLMHead(self.vocab_size, + config.d_model, + embed_scale=embed_scale) + + self.logits_processor = LogitsProcessor(self.vocab_size, + config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + encoder_input_ids + torch.Tensor of *encoder* input token ids. + encoder_positions + torch.Tensor of *encoder* position indices + Returns: + Output torch.Tensor + """ + + return self.model(input_ids, + positions, + encoder_input_ids, + encoder_positions, + inputs_embeds=inputs_embeds) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.encoder.embed_tokens(input_ids) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if "final_logits_bias" in name: + continue + if self.config.tie_word_embeddings and "embed_tokens" in name: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Florence2ProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self): + return self.ctx.get_hf_processor() + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_num_image_tokens(self) -> int: + processor_config = self.ctx.get_hf_image_processor_config() + return processor_config["image_seq_length"] + + +class Florence2DummyInputsBuilder( + BaseDummyInputsBuilder[Florence2ProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width = target_height = self.info.get_hf_config().projection_dim + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + +class Florence2MultiModalProcessor( + EncDecMultiModalProcessor[Florence2ProcessingInfo]): + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def create_encoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return prompt + + def create_decoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return [self.info.get_hf_config().eos_token_id] + + def _apply_hf_processor_tokens_only( + self, + prompt_tokens: list[int], + ) -> list[int]: + hf_processor = self.info.get_hf_processor() + tokenizer: BartTokenizer = hf_processor.tokenizer + prompt_text = tokenizer.decode(prompt_tokens) + # convert task tokens to prompt + prompt_text = hf_processor._construct_prompts([prompt_text])[0] + prompt_tokens = tokenizer.encode(prompt_text, add_special_tokens=False) + return prompt_tokens + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + if mm_data: + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs, tok_kwargs) + else: + hf_processor = self.info.get_hf_processor() + tokenizer = hf_processor.tokenizer + prompt = hf_processor._construct_prompts([prompt])[0] + processed_outputs = tokenizer(prompt, + add_special_tokens=True, + return_tensors="pt") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_config = self.info.get_hf_config() + pad_token_id = hf_config.pad_token_id + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [pad_token_id] * num_image_tokens + + return [ + PromptInsertion( + modality="image", + target=PromptIndexTargets.start(), + insertion=image_tokens, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + Florence2MultiModalProcessor, + info=Florence2ProcessingInfo, + dummy_inputs=Florence2DummyInputsBuilder) +class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsV0Only): + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return None + + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + processor_config = vllm_config.model_config.hf_image_processor_config + + self.config = config + self.vision_config = config.vision_config + self.processor_config = processor_config + assert config.vision_config.model_type == 'davit', ( + 'only DaViT is supported for now') + self.vision_tower = DaViT.from_config(config=config.vision_config) + self._build_image_projection_layers(config) + self.language_model = Florence2LanguageForConditionalGeneration( + vllm_config=vllm_config.with_hf_config(config.text_config), + prefix=f"{prefix}.language_model", + ) + self.pad_token_id = config.pad_token_id + + def _build_image_projection_layers(self, config: PretrainedConfig): + image_dim_out = config.vision_config.dim_embed[-1] + dim_projection = config.vision_config.projection_dim + self.image_projection = nn.Parameter( + torch.empty(image_dim_out, dim_projection)) + self.image_proj_norm = nn.LayerNorm(dim_projection) + image_pos_embed_config = config.vision_config.image_pos_embed + if image_pos_embed_config['type'] == 'learned_abs_2d': + self.image_pos_embed = LearnedAbsolutePositionEmbedding2D( + embedding_dim=image_dim_out, + num_pos=image_pos_embed_config['max_pos_embeddings']) + else: + raise NotImplementedError("Florence2 only supports learned_abs_2d " + "as image position embedding.") + + self.image_feature_source = config.vision_config.image_feature_source + + # temporal embedding + visual_temporal_embedding_config = ( + self.vision_config.visual_temporal_embedding) + if visual_temporal_embedding_config['type'] == 'COSINE': + self.visual_temporal_embed = PositionalEmbeddingCosine1D( + embed_dim=image_dim_out, + max_seq_len=visual_temporal_embedding_config[ + 'max_temporal_embeddings']) + else: + raise NotImplementedError( + 'Florence2 only supports COSINE as temporal embedding.') + + def _validate_pixel_values( + self, data: Union[torch.Tensor, list[torch.Tensor]] + ) -> Union[torch.Tensor, list[torch.Tensor]]: + + size = self.processor_config["size"] + h, w = size["height"], size["width"] + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = tuple(*map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input(self, **kwargs: object): + pixel_values: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "pixel_values", None) + image_embeds: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None and image_embeds is not None: + raise ValueError( + "Both pixel values and image embeds are provided.") + + if pixel_values is not None: + return Florence2ImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), + ) + + if image_embeds is not None: + raise NotImplementedError + + raise AssertionError("This line should be unreachable.") + + def _encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor: + dtype = next(self.vision_tower.parameters()).dtype + pixel_values = pixel_values.to(dtype) + + batch_size, T = pixel_values.size(0), 1 + x = self.vision_tower.forward_features_unpool(pixel_values) + if self.image_pos_embed is not None: + x = x.view(batch_size * T, -1, x.shape[-1]) + num_tokens = x.shape[-2] + h, w = int(num_tokens**0.5), int(num_tokens**0.5) + assert h * w == num_tokens, ( + 'only support square feature maps for now') + x = x.view(batch_size * T, h, w, x.shape[-1]) + pos_embed = self.image_pos_embed(x) + x = x + pos_embed + x = x.view(batch_size, T * h * w, x.shape[-1]) + + if self.visual_temporal_embed is not None: + visual_temporal_embed = self.visual_temporal_embed( + x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]) + x = x.view(batch_size, T, -1, + x.shape[-1]) + visual_temporal_embed.view( + 1, T, 1, x.shape[-1]) + + x_feat_dict = {} + + spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2) + x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x + + temporal_avg_pool_x = x.view(batch_size, T, -1, + x.shape[-1]).mean(dim=1) + x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x + + x = x.view(batch_size, T, -1, x.shape[-1])[:, -1] + x_feat_dict['last_frame'] = x + + new_x = [] + for _image_feature_source in self.image_feature_source: + if _image_feature_source not in x_feat_dict: + raise ValueError('invalid image feature source: {}'.format( + _image_feature_source)) + new_x.append(x_feat_dict[_image_feature_source]) + + x = torch.cat(new_x, dim=1) + + x = x @ self.image_projection + x = self.image_proj_norm(x) + + return x + + def _process_image_input( + self, image_input: Florence2ImagePixelInputs) -> torch.Tensor: + assert image_input["type"] == "pixel_values" + pixel_values = image_input["data"] + return self._encode_image(pixel_values) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None \ + and len(multimodal_embeddings) != 0: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.pad_token_id) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + *, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + encoder_input_ids + torch.Tensor of *encoder* input token ids. + encoder_positions + torch.Tensor of *encoder* position indices + Returns: + Output torch.Tensor + """ + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + if encoder_input_ids.numel() > 0 or vision_embeddings is not None: + inputs_embeds = self.get_input_embeddings(encoder_input_ids, + vision_embeddings) + else: + inputs_embeds = None + + hidden_states = self.language_model(input_ids, + positions, + encoder_input_ids, + encoder_positions, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/fm9g.py b/vllm/model_executor/models/fm9g.py new file mode 100644 index 0000000..82fb9f9 --- /dev/null +++ b/vllm/model_executor/models/fm9g.py @@ -0,0 +1,592 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only FM9G model compatible with HuggingFace weights.""" +import math +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union, List + +import torch +from torch import nn +from vllm.transformers_utils.configs import FM9GConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class FM9GMoE(nn.Module): + """A tensor-parallel MoE implementation that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + ): + super().__init__() + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.num_total_experts = num_experts + self.top_k = top_k + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size // self.tp_size + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + self.gate = ReplicatedLinear(self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=None) + + self.ws = nn.Parameter( + torch.empty(self.num_total_experts, + 2 * self.intermediate_size, + self.hidden_size, + device=current_platform.device_type, + dtype=self.params_dtype)) + self.w2s = nn.Parameter( + torch.empty(self.num_total_experts, + self.hidden_size, + self.intermediate_size, + device=current_platform.device_type, + dtype=self.params_dtype)) + + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + if weight_name.endswith("w1.weight"): + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w3.weight"): + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w2.weight"): + param_data[expert_id, :, :] = loaded_weight[:, shard] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = fused_moe(hidden_states, + self.ws, + self.w2s, + router_logits, + self.top_k, + renormalize=True, + inplace=True) + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_size) + + +class FM9GMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + hidden_act_param: float, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act == "silu": + self.act_fn = SiluAndMul() + elif hidden_act == "fatrelu": + self.act_fn = FatreluAndMul(threshold=hidden_act_param) + else: + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu and fatrelu are supported for now.") + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class FM9GAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + # set rope as fp32 instead of bf16 + self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache( + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + orig_dtype = q.dtype + q, k = q.float(), k.float() + q, k = self.rotary_emb(positions, q, k) + q, k = q.to(orig_dtype), k.to(orig_dtype) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class FM9GDecoderLayer(nn.Module): + + def __init__( + self, + config: FM9GConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.cache_config = cache_config + self.quant_config = quant_config + self.hidden_size = config.hidden_size + self.rope_theta = getattr(config, "rope_theta", 10000) + self.rope_scaling = getattr(config, "rope_scaling", None) + self.max_position_embeddings = getattr(config, + "max_position_embeddings", 8192) + self.prefix = prefix + self._init_attn_block() + self._init_ffn_block() + + def _init_attn_block(self): + self.input_layernorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + self.self_attn = FM9GAttention( + hidden_size=self.hidden_size, + num_heads=self.config.num_attention_heads, + num_kv_heads=self.config.num_key_value_heads, + rope_theta=self.rope_theta, + rope_scaling=self.rope_scaling, + max_position_embeddings=self.max_position_embeddings, + cache_config=self.cache_config, + quant_config=self.quant_config, + prefix=f"{self.prefix}.self_attn", + ) + + def _init_ffn_block(self): + self.post_attention_layernorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + self.num_experts = getattr(self.config, "num_experts", 0) + if self.num_experts == 0: + self.mlp = FM9GMLP( + hidden_size=self.hidden_size, + intermediate_size=self.config.intermediate_size, + hidden_act=self.config.hidden_act, + hidden_act_param=getattr(self.config, "hidden_act_param", 0.), + quant_config=self.quant_config, + ) + else: + self.mlp = FM9GMoE( + num_experts=self.config.num_experts, + top_k=self.config.num_experts_per_tok, + hidden_size=self.config.hidden_size, + intermediate_size=self.config.intermediate_size) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states = residual + hidden_states * \ + (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * \ + (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)) + + return hidden_states, None + + +@support_torch_compile +class FM9GModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.cache_config = cache_config + self.quant_config = quant_config + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.num_experts = getattr(self.config, "num_experts", 0) + self._init_layers(prefix, config, cache_config, quant_config) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], self.config.hidden_size)) + + def _init_layers( + self, + prefix: str, + config: FM9GConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig], + ): + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: FM9GDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + embedding = self.embed_tokens(input_ids) + return embedding * self.config.scale_emb + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states = self.norm(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + expert_params_mapping = [ + # (param_name, weight_name, expert_id) + ("ws" if weight_name in ["w1", "w3"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", expert_id) + for expert_id in range(self.num_experts) + for weight_name in ["w1", "w2", "w3"] + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class FM9GForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.prefix = prefix + self.vllm_config = vllm_config + self.config = config + self.lora_config = lora_config + self.cache_config = cache_config + self.quant_config = quant_config + + self.model = self._init_model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + unpadded_vocab_size = config.vocab_size + if lora_config: + unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + self.scale_width = self.config.hidden_size / self.config.dim_model_base + + self.logits_processor = LogitsProcessor(unpadded_vocab_size, + config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): + return FM9GModel(vllm_config=vllm_config, prefix=prefix) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + hidden_states = hidden_states / self.scale_width + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py new file mode 100644 index 0000000..26c8f80 --- /dev/null +++ b/vllm/model_executor/models/fuyu.py @@ -0,0 +1,406 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py +# Copyright 2023 The vLLM team. +# Copyright 2023 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Fuyu model.""" +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Literal, Optional, TypedDict + +import torch +import torch.nn as nn +from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor, + FuyuProcessor) + +from vllm.config import VllmConfig +from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.model_executor.models.persimmon import PersimmonForCausalLM +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) +from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, + MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix, + merge_multimodal_embeddings) + +# Cannot find the following 2 numbers from hf config. +_IMAGE_TOKEN_ID = 71011 +_NEWLINE_TOKEN_ID = 71019 + + +class FuyuImagePatchInputs(TypedDict): + type: Literal["image_patches"] + flat_data: torch.Tensor + """ + Shape: + `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)` + """ + + patches_per_image: list[int] + """ + The number of total patches for each image in the batch. + + This is used to split the embeddings which has the first two dimensions + flattened just like `flat_data`. + """ + + +class FuyuProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(FuyuConfig) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(FuyuProcessor, **kwargs) + + def get_image_processor(self) -> FuyuImageProcessor: + return self.get_hf_processor().image_processor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_image_feature_grid_size( + self, + *, + image_width: int, + image_height: int, + ) -> tuple[int, int]: + image_processor = self.get_image_processor() + target_width = image_processor.size["width"] + target_height = image_processor.size["height"] + patch_width = image_processor.patch_size["width"] + patch_height = image_processor.patch_size["height"] + + if not (image_width <= target_width and image_height <= target_height): + height_scale_factor = target_height / image_height + width_scale_factor = target_width / image_width + optimal_scale_factor = min(height_scale_factor, width_scale_factor) + + image_height = int(image_height * optimal_scale_factor) + image_width = int(image_width * optimal_scale_factor) + + ncols = math.ceil(image_width / patch_width) + nrows = math.ceil(image_height / patch_height) + return ncols, nrows + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + ncols, nrows = self.get_image_feature_grid_size( + image_width=image_width, + image_height=image_height, + ) + + return ncols * nrows + + def get_image_size_with_most_features(self) -> ImageSize: + image_processor = self.get_image_processor() + return ImageSize(width=image_processor.size["width"], + height=image_processor.size["height"]) + + +class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + target_width, target_height = \ + self.info.get_image_size_with_most_features() + num_images = mm_counts.get("image", 0) + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + +class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + if not mm_data: + # Avoid warning from HF logger for text-only input + prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + image_patches = processed_outputs.get("image_patches") + if image_patches is not None: + images = mm_data["images"] + assert isinstance(images, list) + + # Original output: (1, num_images, Pn, Px * Py * C) + # New output: (num_images, Pn, Px * Py * C) + assert (isinstance(image_patches, list) + and len(image_patches) == 1) + assert (isinstance(image_patches[0], torch.Tensor) + and len(image_patches[0]) == len(images)) + + processed_outputs["image_patches"] = image_patches[0] + + return processed_outputs + + def _apply_hf_processor_tokens_only( + self, + prompt_tokens: list[int], + ) -> list[int]: + # HF processor adds boa_token_id + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + boa_token_id = vocab["<0x04>"] + + return prompt_tokens + [boa_token_id] + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(image_patches=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_config = self.info.get_hf_config() + bos_token_id = hf_config.bos_token_id + assert isinstance(bos_token_id, int) + + tokenizer = self.info.get_tokenizer() + eot_token_id = tokenizer.bos_token_id + assert isinstance(eot_token_id, int) + + def get_replacement_fuyu(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + + ncols, nrows = self.info.get_image_feature_grid_size( + image_width=image_size.width, + image_height=image_size.height, + ) + image_tokens = ([_IMAGE_TOKEN_ID] * ncols + + [_NEWLINE_TOKEN_ID]) * nrows + + return PromptUpdateDetails.select_token_id( + image_tokens + [bos_token_id], + embed_token_id=_IMAGE_TOKEN_ID, + ) + + return [ + PromptReplacement( + modality="image", + target=[eot_token_id], + replacement=get_replacement_fuyu, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor, + info=FuyuProcessingInfo, + dummy_inputs=FuyuDummyInputsBuilder) +class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.vision_embed_tokens.": "vision_embed_tokens.", + "model.language_model.": "language_model.model.", + "lm_head.": "language_model.lm_head.", + }) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return None + + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + + self.vocab_size = config.text_config.vocab_size + self.image_token_id = _IMAGE_TOKEN_ID + self.image_feature_size = config.patch_size**2 * config.num_channels + + self.vision_embed_tokens = ColumnParallelLinear( + self.image_feature_size, + config.hidden_size, + quant_config=quant_config, + gather_output=True, + ) + self.language_model = PersimmonForCausalLM( + vllm_config=vllm_config.with_hf_config(config.text_config), + prefix=maybe_prefix(prefix, "language_model"), + ) + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + + h = w = self.config.patch_size + num_channels = self.config.num_channels + expected_dims = num_channels * h * w + + def _validate_shape(d: torch.Tensor): + actual_dims = d.size(-1) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + "The expected shape of pixel values per image per batch " + f"per patch is {expected_expr}. " + f"You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data.to(self.vision_embed_tokens.weight.dtype) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[FuyuImagePatchInputs]: + image_patches = kwargs.pop("image_patches", None) + if image_patches is not None: + if not isinstance(image_patches, (torch.Tensor, list)): + raise ValueError("Incorrect type of image patches. " + f"Got type: {type(image_patches)}") + + image_patches_flat = flatten_bn(image_patches) + + return FuyuImagePatchInputs( + type="image_patches", + flat_data=self._validate_pixel_values( + flatten_bn(image_patches_flat, concat=True)), + patches_per_image=[x.size(0) for x in image_patches_flat], + ) + + return None + + def _process_image_input( + self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings: + image_patches_flat = image_input["flat_data"] + patches_per_image = image_input["patches_per_image"] + + assert self.vision_embed_tokens is not None + vision_embeddings_flat, _ = self.vision_embed_tokens( + image_patches_flat) + + return vision_embeddings_flat.split(patches_per_image, dim=0) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + + return self._process_image_input(image_input) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None \ + and len(multimodal_embeddings) != 0: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + _IMAGE_TOKEN_ID, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.language_model.logits_processor( + self.language_model.lm_head, hidden_states, sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py new file mode 100644 index 0000000..59c3102 --- /dev/null +++ b/vllm/model_executor/models/gemma.py @@ -0,0 +1,427 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2023 The vLLM team. +# Copyright (c) Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Gemma model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from functools import cache +from typing import Optional, Union + +import torch +from torch import nn +from transformers import GemmaConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import GeluAndMul +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +@cache +def _get_gemma_act_fn( + hidden_act: Optional[str], + hidden_activation: Optional[str], +) -> nn.Module: + if hidden_activation is None: + if hidden_act is not None: + logger.warning( + "Gemma's activation function was incorrectly set to exact GeLU " + "in the config JSON file when it was initially released. " + "Changing the activation function to approximate GeLU " + "(`gelu_pytorch_tanh`). If you want to use the legacy " + "`%s`, edit the config JSON to set " + "`hidden_activation=%s` instead of `hidden_act`. " + "See https://github.com/huggingface/transformers/pull/29402 " + "for more details.", hidden_act, hidden_act) + return GeluAndMul(approximate="tanh") + elif hidden_activation == "gelu_pytorch_tanh": + return GeluAndMul(approximate="tanh") + elif hidden_activation == "gelu": + return GeluAndMul(approximate="none") + else: + raise ValueError(f"Activation function {hidden_act} is not " + "supported for Gemma models.") + + +class GemmaMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: Optional[str] = None, + hidden_activation: Optional[str] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation) + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class GemmaAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int = 8192, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=self.rope_theta, + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class GemmaDecoderLayer(nn.Module): + + def __init__( + self, + config: GemmaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = GemmaAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + max_position_embeddings=config.max_position_embeddings, + rope_theta=config.rope_theta, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = GemmaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + hidden_activation=getattr(config, "hidden_activation", None), + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class GemmaModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: GemmaDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Normalize the embedding by sqrt(hidden_size) + # The normalizer's data type should be downcasted to the model's + # data type such as bfloat16, not float32. + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = self.config.hidden_size**0.5 + self.register_buffer("normalizer", + torch.tensor(normalizer), + persistent=False) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + hidden_states *= self.normalizer + residual = None + else: + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + for (param_name, shard_name, shard_id) in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + # currently all existing Gemma models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = GemmaModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.model.embed_tokens, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py new file mode 100644 index 0000000..8beefb2 --- /dev/null +++ b/vllm/model_executor/models/gemma2.py @@ -0,0 +1,427 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 The vLLM team. +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch import nn +from transformers import Gemma2Config + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import GeluAndMul +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class Gemma2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + hidden_activation: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"): + raise ValueError( + "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation " + "function. Please set `hidden_act` and `hidden_activation` to " + "`gelu_pytorch_tanh`.") + self.act_fn = GeluAndMul(approximate="tanh") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Gemma2Attention(nn.Module): + + def __init__(self, + config: Gemma2Config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + rope_theta: float, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + attn_logits_soft_cap: Optional[float] = None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = config.query_pre_attn_scalar**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=self.rope_theta, + is_neox_style=True, + ) + + # reference: + # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa + layer_idx = extract_layer_index(prefix) + use_sliding_window = (layer_idx % 2 == 0 and getattr( + config, "interleaved_sliding_window", None) is not None) + sliding_window = config.interleaved_sliding_window if \ + use_sliding_window else None + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Gemma2DecoderLayer(nn.Module): + + def __init__( + self, + config: Gemma2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Gemma2Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + max_position_embeddings=config.max_position_embeddings, + rope_theta=config.rope_theta, + cache_config=cache_config, + quant_config=quant_config, + attn_logits_soft_cap=config.attn_logit_softcapping, + prefix=f"{prefix}.self_attn", + ) + self.hidden_size = config.hidden_size + self.mlp = Gemma2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + hidden_activation=config.hidden_activation, + quant_config=quant_config, + ) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states, residual = self.pre_feedforward_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class Gemma2Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Gemma2DecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Normalize the embedding by sqrt(hidden_size) + # The normalizer's data type should be downcasted to the model's + # data type such as bfloat16, not float32. + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = self.config.hidden_size**0.5 + self.register_buffer("normalizer", + torch.tensor(normalizer), + persistent=False) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + hidden_states *= self.normalizer + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for (param_name, shard_name, shard_id) in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + del lora_config # Unused. + super().__init__() + self.config = config + # currently all existing Gemma models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings + self.quant_config = quant_config + self.model = Gemma2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.logits_processor = LogitsProcessor( + config.vocab_size, soft_cap=config.final_logit_softcapping) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.model.embed_tokens, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py new file mode 100644 index 0000000..954e48d --- /dev/null +++ b/vllm/model_executor/models/gemma3.py @@ -0,0 +1,535 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 The vLLM team. +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Iterable +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import Gemma3TextConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import GeluAndMul +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class Gemma3MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_activation: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_activation != "gelu_pytorch_tanh": + raise ValueError( + "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation " + "function. Please set `hidden_act` and `hidden_activation` to " + "`gelu_pytorch_tanh`.") + self.act_fn = GeluAndMul(approximate="tanh") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Gemma3Attention(nn.Module): + + def __init__(self, + config: Gemma3TextConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + attn_logits_soft_cap: Optional[float] = None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = config.query_pre_attn_scalar**-0.5 + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + # TODO(woosuk): Add reference to the original HF implementation. + layer_idx = extract_layer_index(prefix) + self.is_sliding = (getattr( + config, "interleaved_sliding_window", None) is not None and bool( + (layer_idx + 1) % config.sliding_window_pattern)) + # Initialize the rotary embedding. + if self.is_sliding: + # Local attention. Override the values in config.json. + self.rope_theta = config.rope_local_base_freq + self.rope_scaling = {"rope_type": "default"} + self.sliding_window = config.interleaved_sliding_window + else: + # Global attention. Use the values in config.json. + self.rope_theta = config.rope_theta + self.rope_scaling = config.rope_scaling + self.sliding_window = None + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=self.rope_theta, + is_neox_style=True, + rope_scaling=self.rope_scaling, + ) + + # Initialize the attention. + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap, + per_layer_sliding_window=self.sliding_window, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + q = self.q_norm(q) + q = q.flatten(-2, -1) + k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) + k = self.k_norm(k) + k = k.flatten(-2, -1) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + + if not kwargs.get("has_images", False): + # Fast path for text-only inputs. The performance for the text-only + # inputs are not affected by the naive attention below. + output, _ = self.o_proj(attn_output) + return output + + # NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens + # that correspond to the same image while using causal attention + # otherwise. Current attention backends cannot handle this pattern, so + # we temporarily use a naive attention implementation with mask tensors. + + # We intentionally keep the attention backend as-is and only override + # `attn_output` with the naive implementation's output. This minimizes + # changes to existing model runners and attention backends. The call to + # `self.attn(q, k, v)` is only used to populate the KV cache - its + # output is discarded and overwritten below. While this duplicates + # computation, it maintains compatibility. + # TODO(woosuk): Optimize by implementing custom attention kernels. + attn_output = self.naive_attn_with_masks(q, + k, + v, + out=attn_output, + **kwargs) + output, _ = self.o_proj(attn_output) + return output + + def naive_attn_with_masks( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + # NOTE(woosuk): As described in the comment above, this code is not + # meant to be performant. It is only meant to be correct. + q = q.view(-1, self.num_heads, self.head_dim) + # Expand the key and value to handle GQA. + num_queries_per_kv = self.num_heads // self.num_kv_heads + k = k.view(-1, self.num_kv_heads, self.head_dim) + k = k.repeat_interleave(num_queries_per_kv, dim=-2) + v = v.view(-1, self.num_kv_heads, self.head_dim) + v = v.repeat_interleave(num_queries_per_kv, dim=-2) + + if self.is_sliding: + attn_masks = kwargs["local_attn_masks"] + else: + attn_masks = kwargs["global_attn_masks"] + + seq_lens = kwargs["seq_lens"] + start_idx = 0 + for seq_len, attn_mask in zip(seq_lens, attn_masks): + end_idx = start_idx + seq_len + query = q[start_idx:end_idx].unsqueeze(0) + key = k[start_idx:end_idx].unsqueeze(0) + value = v[start_idx:end_idx].unsqueeze(0) + + # Transpose. + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + output = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask, + self.scaling, + ) + output = output.transpose(1, 2).flatten(-2, -1) + out[start_idx:end_idx] = output + start_idx = end_idx + return out + + +class Gemma3DecoderLayer(nn.Module): + + def __init__( + self, + config: Gemma3TextConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Gemma3Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + max_position_embeddings=config.max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + attn_logits_soft_cap=None, + prefix=f"{prefix}.self_attn", + ) + self.hidden_size = config.hidden_size + self.mlp = Gemma3MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_activation=config.hidden_activation, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + **kwargs, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states, residual = self.pre_feedforward_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class Gemma3Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=f"{prefix}.embed_tokens", + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Gemma3DecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Normalize the embedding by sqrt(hidden_size) + # The normalizer's data type should be downcasted to the model's + # data type such as bfloat16, not float32. + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = self.config.hidden_size**0.5 + self.register_buffer("normalizer", + torch.tensor(normalizer), + persistent=False) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + # NOTE(woosuk): Only apply the normalizer to the output of + # vocab embedding. Don't apply it to the vision embedding. + return self.embed_tokens(input_ids) * self.normalizer + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + **kwargs, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for (param_name, shard_name, shard_id) in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + del lora_config # Unused. + super().__init__() + self.config = config + # currently all existing Gemma models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings + self.quant_config = quant_config + self.model = Gemma3Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.logits_processor = LogitsProcessor( + config.vocab_size, soft_cap=config.final_logit_softcapping) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds, **kwargs) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.model.embed_tokens, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py new file mode 100644 index 0000000..d14f5fa --- /dev/null +++ b/vllm/model_executor/models/gemma3_mm.py @@ -0,0 +1,729 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Literal, Optional, TypedDict + +import torch +from torch import nn +from transformers import BatchFeature, Gemma3Config, Gemma3Processor +from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs + +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) +from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, + MultiModalDataItems) +# yapf: disable +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, BoundPromptUpdate, + PlaceholderFeaturesInfo, + PromptReplacement, PromptTargetMatch, + PromptUpdate, PromptUpdateDetails, + find_mm_placeholders, + replace_token_matches) +# yapf: enable +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors + +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .siglip import SiglipVisionModel +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) + +logger = init_logger(__name__) + + +class Gemma3ImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: torch.Tensor + """ + Shape: `(num_patches_total, num_channels, height, width)` + + `num_patches_total` is the total number of patches + over each image over each prompt in the batch. + """ + + num_patches: torch.Tensor + """Shape: `(batch_size * num_images)`""" + + +Gemma3ImageInputs = Gemma3ImagePixelInputs + + +class Gemma3ProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Gemma3Config) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(Gemma3Processor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def _resolve_image_kwargs( + self, + processor: Gemma3Processor, + keys: set[str], + ) -> dict[str, Any]: + image_processor = processor.image_processor + kwargs = processor._merge_kwargs( + Gemma3ProcessorKwargs, + tokenizer_init_kwargs=processor.tokenizer.init_kwargs, + ) + + images_kwargs = kwargs["images_kwargs"] + + def _resolve_kw(key: str): + val = getattr(image_processor, key) + if val is None: + val = images_kwargs[key] + + return val + + return {k: _resolve_kw(k) for k in keys} + + def get_num_crops( + self, + *, + image_width: int, + image_height: int, + processor: Optional[Gemma3Processor], + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + images_kwargs = self._resolve_image_kwargs( + processor, { + "do_pan_and_scan", "pan_and_scan_min_crop_size", + "pan_and_scan_max_num_crops", + "pan_and_scan_min_ratio_to_activate" + }) + + do_pan_and_scan = images_kwargs["do_pan_and_scan"] + pan_and_scan_min_crop_size = images_kwargs[ + "pan_and_scan_min_crop_size"] + pan_and_scan_max_num_crops = images_kwargs[ + "pan_and_scan_max_num_crops"] + pan_and_scan_min_ratio_to_activate = images_kwargs[ + "pan_and_scan_min_ratio_to_activate"] + + if not do_pan_and_scan: + return 0 + + if envs.VLLM_USE_V1: + logger.warning_once( + "`do_pan_and_scan=True` has suboptimal results on V1 " + "because of the simplified attention pattern being used.") + + # Based on Gemma3ImageProcessor.pan_and_scan + if image_width >= image_height: + if image_width / image_height < pan_and_scan_min_ratio_to_activate: + return 0 + + num_crops_w = min( + int(math.floor(image_width / pan_and_scan_min_crop_size)), + int(math.floor(image_width / image_height + 0.5)), + ) + + num_crops_w = max(2, num_crops_w) + num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w) + num_crops_h = 1 + else: + if image_height / image_width < pan_and_scan_min_ratio_to_activate: + return 0 + + num_crops_h = min( + int(math.floor(image_height / pan_and_scan_min_crop_size)), + int(math.floor(image_height / image_width + 0.5)), + ) + + num_crops_h = max(2, num_crops_h) + num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h) + num_crops_w = 1 + + crop_size_w = int(math.ceil(image_width / num_crops_w)) + crop_size_h = int(math.ceil(image_height / num_crops_h)) + + if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size: + return 0 + + return num_crops_w * num_crops_h + + def get_image_repl( + self, + *, + image_width: int, + image_height: int, + processor: Optional[Gemma3Processor], + ) -> PromptUpdateDetails[str]: + if processor is None: + processor = self.get_hf_processor() + + boi_token = processor.boi_token + + num_crops = self.get_num_crops( + image_width=image_width, + image_height=image_height, + processor=processor, + ) + + if num_crops == 0: + image_text = boi_token + else: + crops_image_tokens = " ".join(boi_token for _ in range(num_crops)) + image_text = ( + f"Here is the original image {boi_token} and here are some " + f"crops to help you see better {crops_image_tokens}") + + repl_full = image_text.replace(boi_token, + processor.full_image_sequence) + + tokenizer = processor.tokenizer + vocab = tokenizer.get_vocab() + image_token_id = vocab[tokenizer.image_token] + + return PromptUpdateDetails.select_token_id(repl_full, image_token_id) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[Gemma3Processor], + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + num_crops = self.get_num_crops( + image_width=image_width, + image_height=image_height, + processor=processor, + ) + image_seq_len = processor.image_seq_length + + return (num_crops + 1) * image_seq_len + + def get_image_size_with_most_features(self) -> ImageSize: + processor = self.get_hf_processor() + + images_kwargs = self._resolve_image_kwargs( + processor, {"pan_and_scan_max_num_crops"}) + max_num_crops = images_kwargs["pan_and_scan_max_num_crops"] + + # Result in the max possible feature size (h:w = max_num_crops:1) + return ImageSize(height=50 * max_num_crops, width=50) + + +class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.boi_token + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + +class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + processed_outputs = super()._call_hf_processor( + prompt, + mm_data, + mm_kwargs, + tok_kwargs, + ) + + # HF processor pops the `num_crops` kwarg, which is needed by vLLM + if (images := mm_data.get("images")) is not None: + parsed_images = (self._get_data_parser().parse_mm_data({ + "image": + images + }).get_items("image", ImageProcessorItems)) + image_sizes = [ + parsed_images.get_image_size(i) + for i in range(len(parsed_images)) + ] + hf_processor = self.info.get_hf_processor(**mm_kwargs) + + num_crops = [ + self.info.get_num_crops(image_width=size.width, + image_height=size.height, + processor=hf_processor) + for size in image_sizes + ] + processed_outputs["num_crops"] = torch.tensor(num_crops) + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + num_crops = hf_inputs.get("num_crops", torch.empty(0)) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", num_crops + 1), + num_crops=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_token = hf_processor.boi_token + + def get_replacement_gemma3(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + + image_size = images.get_image_size(item_idx) + return self.info.get_image_repl( + image_width=image_size.width, + image_height=image_size.height, + processor=hf_processor, + ) + + return [ + PromptReplacement( + modality="image", + target=image_token, + replacement=get_replacement_gemma3, + ) + ] + + def _apply_token_matches( + self, + prompt: list[int], + mm_matches: Mapping[str, Sequence[PromptTargetMatch]], + mm_item_counts: Mapping[str, int], + ) -> list[int]: + token_ids = super()._apply_token_matches( + prompt, + mm_matches, + mm_item_counts, + ) + + # "\n\n\n" and "\n\n\n\n" are single tokens + # Since our replacement can insert "\n\n" next to "\n" + # tokens, we have to combine them to be consistent with + # the output of the tokenizer + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + newline_1 = vocab["\n"] + newline_2 = vocab["\n\n"] + newline_3 = vocab["\n\n\n"] + newline_4 = vocab["\n\n\n\n"] + + token_ids = replace_token_matches( + token_ids, + [newline_1, newline_2], + [newline_3], + ) + token_ids = replace_token_matches( + token_ids, + [newline_2, newline_1], + [newline_3], + ) + token_ids = replace_token_matches( + token_ids, + [newline_2, newline_2], + [newline_4], + ) + + return token_ids + + def _find_mm_placeholders( + self, + mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], + new_token_ids: list[int], + mm_item_counts: Mapping[str, int], + ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: + # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n" + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + newline_1 = vocab["\n"] + newline_2 = vocab["\n\n"] + newline_3 = vocab["\n\n\n"] + newline_4 = vocab["\n\n\n\n"] + + def get_repl_toks(tok: int) -> list[int]: + if tok == newline_3: + return [newline_1, newline_2] + if tok == newline_4: + return [newline_2, newline_2] + + return [tok] + + repl_token_ids = list[int]() + repl_orig_idxs = list[int]() + for orig_idx, orig_tok in enumerate(new_token_ids): + repl_toks = get_repl_toks(orig_tok) + repl_token_ids.extend(repl_toks) + repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) + + repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids, + mm_item_counts) + + return { + modality: [ + PlaceholderFeaturesInfo( + modality=p.modality, + item_idx=p.item_idx, + start_idx=repl_orig_idxs[p.start_idx], + tokens=p.tokens, + is_embed=p.is_embed, + ) for p in placeholders + ] + for modality, placeholders in repls.items() + } + + +class Gemma3MultiModalProjector(nn.Module): + + def __init__(self, config: Gemma3Config): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.zeros(config.vision_config.hidden_size, + config.text_config.hidden_size)) + + self.mm_soft_emb_norm = GemmaRMSNorm( + config.vision_config.hidden_size, + eps=config.vision_config.layer_norm_eps) + + self.patches_per_image = int(config.vision_config.image_size // + config.vision_config.patch_size) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, + stride=self.kernel_size) + + def forward(self, vision_outputs: torch.Tensor): + batch_size, _, seq_length = vision_outputs.shape + + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, seq_length, self.patches_per_image, + self.patches_per_image) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + projected_vision_outputs = torch.matmul( + normed_vision_outputs, self.mm_input_projection_weight) + return projected_vision_outputs.type_as(vision_outputs) + + +@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor, + info=Gemma3ProcessingInfo, + dummy_inputs=Gemma3DummyInputsBuilder) +class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, + SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # mapping for new names in checkpoint saved after transformers v4.52 + "model.language_model.": "language_model.model.", + "model.vision_tower.": "vision_tower.", + "model.multi_modal_projector.": "multi_modal_projector.", + "lm_head.": "language_model.lm_head.", + }) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "" + + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.quant_config = quant_config + self.multimodal_config = multimodal_config + self.sliding_window = getattr(config.text_config, + "interleaved_sliding_window", None) + + self.vision_tower = SiglipVisionModel(config.vision_config, + quant_config, + prefix=maybe_prefix( + prefix, "vision_tower")) + self.multi_modal_projector = Gemma3MultiModalProjector(config) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Gemma3ForCausalLM"], + ) + logit_scale = getattr(config, "logit_scale", 1.0) + self.language_model.logits_processor.scale *= logit_scale + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @property + def dtype(self): + return next(self.parameters()).dtype + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + image_size = self.config.vision_config.image_size + expected_dims = (3, image_size, image_size) + if data.shape[1:] != expected_dims: + raise ValueError( + "The expected shape of pixel values per image per batch is " + f"{expected_dims}. You supplied {tuple(data.shape)}.") + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Gemma3ImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + num_crops = kwargs.pop("num_crops", None) + image_embeds = kwargs.pop("image_embeds", None) + assert image_embeds is None, "Gemma3 does not support image_embeds." + if pixel_values is None: + return None + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(num_crops, (torch.Tensor, list)): + raise ValueError("Incorrect type of num_crops. " + f"Got type: {type(num_crops)}") + + pixel_values = flatten_bn(pixel_values, concat=True) + num_crops = flatten_bn(num_crops, concat=True) + + return Gemma3ImagePixelInputs( + type="pixel_values", + pixel_values=self._validate_pixel_values(pixel_values), + num_patches=num_crops + 1, + ) + + def _image_pixels_to_features( + self, + vision_tower: SiglipVisionModel, + pixel_values: torch.Tensor, + ) -> torch.Tensor: + return vision_tower(pixel_values) + + def _process_image_input( + self, + image_input: Gemma3ImageInputs, + ) -> list[torch.Tensor]: + assert self.vision_tower is not None + + pixel_values = image_input["pixel_values"] + num_patches = image_input["num_patches"] + + image_features = self._image_pixels_to_features( + self.vision_tower, + pixel_values, + ) + image_embeds = self.multi_modal_projector(image_features) + + return [ + e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist()) + ] + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + + return self._process_image_input(image_input) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None \ + and len(multimodal_embeddings) != 0: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + self.config.image_token_index, + ) + return inputs_embeds + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object) -> IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + if vision_embeddings is not None: + kwargs = self.prepare_attn_masks( + input_ids, + positions, + mask_dtype=self.dtype, + **kwargs, + ) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs) + + return hidden_states + + def prepare_attn_masks( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mask_dtype: torch.dtype, + **kwargs, + ): + kwargs["has_images"] = True + # NOTE(woosuk): Here, we distinguish the sequences by the position id 0. + # This is a HACK. Fix this. + start_indices = (positions == 0).cpu().nonzero() + num_seqs = len(start_indices) + seq_lens = [] + for i in range(num_seqs): + start_idx = start_indices[i].item() + if i < num_seqs - 1: + end_idx = start_indices[i + 1].item() + else: + end_idx = len(input_ids) + seq_lens.append(end_idx - start_idx) + kwargs["seq_lens"] = seq_lens + + global_attn_masks = [] + local_attn_masks = [] + start_idx = 0 + for seq_len in seq_lens: + end_idx = start_idx + seq_len + input_token_ids = input_ids[start_idx:end_idx] + start_idx = end_idx + # Create a global causal mask. + global_attn_mask = torch.empty( + 1, + 1, + seq_len, + seq_len, + dtype=mask_dtype, + device=input_ids.device, + ) + global_attn_mask.fill_(float("-inf")) + # Fill the lower triangle with 0. + global_attn_mask = global_attn_mask.triu(diagonal=1) + + # Consider the bidirectional attention between image tokens. + img_mask = torch.zeros_like(global_attn_mask) + img_pos = (input_token_ids == self.config.image_token_index) + img_mask[:, :, :, img_pos] += 1 + img_mask[:, :, img_pos, :] += 1 + global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask) + global_attn_masks.append(global_attn_mask) + + if self.sliding_window is not None: + # Create a local causal mask with sliding window (1024). + local_attn_mask = torch.ones_like(global_attn_mask) + local_attn_mask = torch.tril(local_attn_mask, + diagonal=-self.sliding_window) + local_attn_mask = torch.where(local_attn_mask == 0, + global_attn_mask, float("-inf")) + local_attn_masks.append(local_attn_mask) + kwargs["global_attn_masks"] = global_attn_masks + kwargs["local_attn_masks"] = local_attn_masks + return kwargs + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="multi_modal_projector", + tower_model="vision_tower") diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py new file mode 100644 index 0000000..7d16332 --- /dev/null +++ b/vllm/model_executor/models/gemma3n.py @@ -0,0 +1,811 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 The vLLM team. +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch import nn +from transformers.models.gemma3n.configuration_gemma3n import Gemma3nTextConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY, + GeluAndMul, + GeluAndMulSparse) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .utils import (AutoWeightsLoader, extract_layer_index, + is_pp_missing_parameter, make_layers, maybe_prefix) + +logger = init_logger(__name__) + + +class Gemma3nAltUp(nn.Module): + """Alternating updates (Altup) + The AltUp module wraps transformer layers. The `predict` step modifies the + input to the transformer layer, and the `correct` step propagates the output + of the transformer layer to the sparsely updated dimensions. + See more in the research paper: + https://proceedings.neurips.cc/paper_files/paper/2023/file/f2059277ac6ce66e7e5543001afa8bb5-Paper-Conference.pdf + """ + + def __init__( + self, + hidden_size: int, + rms_norm_eps: float, + altup_num_inputs: int, + altup_coef_clip: float, + altup_active_idx: int, + prefix: str, + ): + super().__init__() + + self.altup_num_inputs = altup_num_inputs + self.altup_active_idx = altup_active_idx + self.altup_coef_clip = altup_coef_clip + + self.correction_coefs = ReplicatedLinear( + altup_num_inputs, + altup_num_inputs, + bias=False, + prefix=f"{prefix}.correction_coefs", + return_bias=False, + ) + self.prediction_coefs = ReplicatedLinear( + altup_num_inputs, + altup_num_inputs**2, + bias=False, + prefix=f"{prefix}.prediction_coefs", + return_bias=False, + ) + self.modality_router = ReplicatedLinear( + hidden_size, + altup_num_inputs, + bias=False, + prefix=f"{prefix}.modality_router", + return_bias=False, + ) + self.router_norm = RMSNorm( + hidden_size=hidden_size, + eps=rms_norm_eps, + ) + self.router_input_scale = torch.tensor( + hidden_size**-1.0, dtype=self.modality_router.weight.dtype) + self.correct_output_scale = nn.Parameter( + torch.zeros(hidden_size, dtype=torch.float32)) + + def _compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor: + router_inputs = self.router_norm(x) * self.router_input_scale + routed = self.modality_router(router_inputs) + return torch.tanh(routed.float()).type_as(x) + + def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor: + return (corrected.type_as(self.correct_output_scale) * + self.correct_output_scale).type_as(corrected) + + def predict(self, hidden_states: torch.Tensor) -> torch.Tensor: + # hidden: [altup_num_inputs, num_tokens, hidden_size] + # modalities: [num_tokens, num_altup_inputs] + # all_coefs: [num_tokens, num_altup_inputs ** 2] + modalities = self._compute_router_modalities( + hidden_states[self.altup_active_idx]) + all_coefs = self.prediction_coefs(modalities) + + # Reshape and transpose the 2D matrix for the matmul. + # all_coefs_T: [num_tokens, num_altup_inputs, num_altup_inputs] + all_coefs_T = all_coefs.reshape( + -1, + self.altup_num_inputs, + self.altup_num_inputs, + ).permute(0, 2, 1) + + # hidden_states to [num_tokens, hidden_size, altup_num_inputs] + predictions = torch.matmul(hidden_states.permute(1, 2, 0), all_coefs_T) + # [altup_num_inputs, num_tokens, hidden_size] + predictions = predictions.permute(2, 0, 1) + predictions += hidden_states + return predictions.contiguous() + + def correct(self, predictions: torch.Tensor, + activated: torch.Tensor) -> torch.Tensor: + # predictions: [altup_num_inputs, num_tokens, hidden_size] + # activated: [num_tokens, hidden_size] + # modalities: [num_tokens, altup_num_inputs] + modalities = self._compute_router_modalities(activated) + # innovation: [num_tokens, altup_num_inputs] + innovation = activated - predictions[self.altup_active_idx] + # innovation: [altup_num_inputs, num_tokens, hidden_size] + innovation = innovation.repeat(self.altup_num_inputs, 1, 1) + + # Permute to [altup_num_inputs, num_tokens] as the last dim + # is a scalar applied to each altup input and expand on + # num_tokens dim for broadcastability over hidden_size. + # all_coefs: [num_tokens, altup_num_inputs] + all_coefs = self.correction_coefs(modalities) + 1.0 + # all_coefs: [altup_num_inputs, num_tokens, 1] + all_coefs = all_coefs.T.unsqueeze(-1) + + # Elementwise (broadcast over hidden_size). + corrected = torch.mul(innovation, all_coefs) + corrected += predictions + + return corrected.contiguous() + + +class Gemma3nLaurelBlock(nn.Module): + """Learned Augmented Residual Layer""" + + def __init__(self, hidden_size: int, laurel_rank: int, rms_norm_eps: float, + prefix: str): + super().__init__() + + self.linear_left = ColumnParallelLinear( + hidden_size, + laurel_rank, + bias=False, + prefix=f"{prefix}.linear_left", + return_bias=False, + ) + self.linear_right = RowParallelLinear(laurel_rank, + hidden_size, + bias=False, + prefix=f"{prefix}.linear_right", + return_bias=False) + self.post_laurel_norm = RMSNorm( + hidden_size=hidden_size, + eps=rms_norm_eps, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + laurel_x = self.linear_left(x) + laurel_x = self.linear_right(laurel_x) + normed_laurel_x = self.post_laurel_norm(laurel_x) + return x + normed_laurel_x + + +class Gemma3nMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_activation: str, + activation_sparsity: float = 0.0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_activation != "gelu_pytorch_tanh": + raise ValueError( + "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation " + "function. Please set `hidden_act` and `hidden_activation` to " + "`gelu_pytorch_tanh`.") + + self.act_fn = GeluAndMulSparse( + activation_sparsity=activation_sparsity, + approximate="tanh") if activation_sparsity > 0.0 else GeluAndMul( + approximate="tanh") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Gemma3nAttention(nn.Module): + + def __init__(self, + config: Gemma3nTextConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.q_norm = RMSNorm(hidden_size=self.head_dim, + eps=config.rms_norm_eps) + self.k_norm = RMSNorm(hidden_size=self.head_dim, + eps=config.rms_norm_eps) + self.v_norm = RMSNorm(hidden_size=self.head_dim, + eps=config.rms_norm_eps, + has_weight=False) + + layer_idx = extract_layer_index(prefix) + if config.layer_types[layer_idx] == "sliding_attention": + self.sliding_window = config.sliding_window + rope_theta = config.rope_local_base_freq + rope_scaling = {"rope_type": "default"} + else: + self.sliding_window = None + rope_theta = config.rope_theta + rope_scaling = config.rope_scaling + + first_kv_shared_layer_idx = (config.num_hidden_layers - + config.num_kv_shared_layers) + self.is_kv_shared = layer_idx >= first_kv_shared_layer_idx + + if self.is_kv_shared: + # Last full attention layer is 1 before sharing + # Last sliding attention layer is 2 before sharing + offset = 2 if self.sliding_window is not None else 1 + kv_shared_layer_index = first_kv_shared_layer_idx - offset + kv_sharing_target_layer_name = f"model.language_model.layers.{kv_shared_layer_index}.self_attn.attn" # noqa: E501 + else: + kv_sharing_target_layer_name = None + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + is_neox_style=True, + rope_scaling=rope_scaling, + ) + + self.attn = Attention( + num_heads=self.num_heads, + head_size=self.head_dim, + scale=1.0, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=self.sliding_window, + kv_sharing_target_layer_name=kv_sharing_target_layer_name, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + q = self.q_norm(q) + q = q.flatten(-2, -1) + k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) + k = self.k_norm(k) + k = k.flatten(-2, -1) + v = v.unflatten(-1, (self.num_kv_heads, self.head_dim)) + v = self.v_norm(v) + v = v.flatten(-2, -1) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + + output, _ = self.o_proj(attn_output) + return output + + +class Gemma3nDecoderLayer(nn.Module): + + def __init__( + self, + config: Gemma3nTextConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.altup_active_idx = config.altup_active_idx + assert config.altup_correct_scale + + self.altup = Gemma3nAltUp( + hidden_size=config.hidden_size, + rms_norm_eps=config.rms_norm_eps, + altup_num_inputs=config.altup_num_inputs, + altup_coef_clip=config.altup_coef_clip, + altup_active_idx=config.altup_active_idx, + prefix=f"{prefix}.altup", + ) + self.self_attn = Gemma3nAttention( + config=config, + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + max_position_embeddings=config.max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = Gemma3nMLP( + hidden_size=config.hidden_size, + # NOTE: Matformer https://github.com/huggingface/transformers/blob/a52478253bbe522a420e88ea3940d4d98a935300/src/transformers/models/gemma3n/modular_gemma3n.py#L258 # noqa: E501 + intermediate_size=config.intermediate_size[extract_layer_index( + prefix)], + hidden_activation=config.hidden_activation, + quant_config=quant_config, + activation_sparsity=config.activation_sparsity_pattern[ + extract_layer_index(prefix)], + prefix=f"{prefix}.mlp", + ) + self.laurel = Gemma3nLaurelBlock( + hidden_size=config.hidden_size, + laurel_rank=config.laurel_rank, + rms_norm_eps=config.rms_norm_eps, + prefix=f"{prefix}.laurel", + ) + + # NOTE(rob): should be ColumnParallelLinear and RowParallelLinear + # But, we need to add per_layer_input_gate(x) to per_layer_input. + # per_layer_input cannot be sharded, so we replicate for now. + self.per_layer_input_gate = ReplicatedLinear( + config.hidden_size, + config.hidden_size_per_layer_input, + bias=False, + prefix=f"{prefix}.per_layer_input_gate", + return_bias=False, + ) + self.per_layer_projection = ReplicatedLinear( + config.hidden_size_per_layer_input, + config.hidden_size, + bias=False, + prefix=f"{prefix}.per_layer_projection", + return_bias=False, + ) + + # LayerNorms. + self.input_layernorm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.pre_feedforward_layernorm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.post_feedforward_layernorm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.post_per_layer_input_norm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + + self.act_fn = _ACTIVATION_REGISTRY[config.hidden_activation] + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + per_layer_input: torch.Tensor, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + + # ActUp (predict). + predictions = self.altup.predict(hidden_states) + active_prediction = predictions[self.altup_active_idx] + active_prediction_normed = self.input_layernorm(active_prediction) + laurel_output = self.laurel(active_prediction_normed) + + # Attention. + attn = self.self_attn( + positions=positions, + hidden_states=active_prediction_normed, + **kwargs, + ) + attn = self.post_attention_layernorm(attn) + attn_gated = attn + active_prediction + attn_laurel = (attn_gated + laurel_output) / torch.sqrt( + torch.tensor(2.0)) + + # MLP. + attn_norm = self.pre_feedforward_layernorm(attn_laurel) + attn_ffw = self.mlp(attn_norm) + attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw) + attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm + + # ActUp (connect). + corrected_predictions = self.altup.correct(predictions, + attn_ffw_laurel_gated) + first_prediction = corrected_predictions[self.altup_active_idx] + first_prediction = self.altup.scale_corrected_output(first_prediction) + + # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...) + first_prediction = self.per_layer_input_gate(first_prediction) + first_prediction = self.act_fn(first_prediction) + first_prediction = torch.mul(first_prediction, per_layer_input) + + # per_layer_projection adapted from jax.numpy.einsum("btp,pd->btd", ...) + first_prediction = self.per_layer_projection(first_prediction) + first_prediction = self.post_per_layer_input_norm(first_prediction) + corrected_predictions[1:] += first_prediction + + return corrected_predictions + + +@support_torch_compile +class Gemma3nTextModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config.text_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=f"{prefix}.embed_tokens", + ) + self.embed_scale = torch.tensor( + config.hidden_size**0.5, + dtype=self.embed_tokens.weight.dtype, + ) + self.embed_tokens_per_layer = VocabParallelEmbedding( + config.vocab_size_per_layer_input, + config.num_hidden_layers * config.hidden_size_per_layer_input, + prefix=f"{prefix}.per_layer_embed_tokens", + ) + self.embed_scale_per_layer = torch.tensor( + config.hidden_size_per_layer_input**0.5, + dtype=self.embed_tokens.weight.dtype, + ) + self.per_layer_model_projection = ColumnParallelLinear( + config.hidden_size, + config.num_hidden_layers * config.hidden_size_per_layer_input, + bias=False, + gather_output=True, + return_bias=False, + prefix=f"{prefix}.per_layer_model_projection", + ) + self.per_layer_projection_norm = RMSNorm( + hidden_size=config.hidden_size_per_layer_input, + eps=config.rms_norm_eps, + ) + self.per_layer_input_scale = torch.rsqrt(torch.tensor(2.0)).to( + self.embed_tokens.weight.dtype) + self.per_layer_projection_scale = torch.tensor( + config.hidden_size**0.5, + dtype=self.embed_tokens.weight.dtype, + ) + self.altup_projections = nn.ModuleList([ + ColumnParallelLinear( + config.hidden_size, + config.hidden_size, + bias=False, + gather_output=True, + return_bias=False, + prefix=f"{prefix}.{idx-1}.altup_projections", + ) for idx in range(1, self.config.altup_num_inputs) + ]) + self.altup_unembed_projections = nn.ModuleList([ + ColumnParallelLinear( + config.hidden_size, + config.hidden_size, + bias=False, + gather_output=True, + return_bias=False, + prefix=f"{prefix}.{idx-1}.altup_unembed_projections", + ) for idx in range(1, self.config.altup_num_inputs) + ]) + + # Transformer blocks. + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Gemma3nDecoderLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.norm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.eps = torch.tensor(torch.finfo().min) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) * self.embed_scale + + def get_per_layer_input_embeddings( + self, input_ids: torch.Tensor) -> torch.Tensor: + # Deal with the fact that vocab_size_per_layer_input < vocab_size + # which causes us to have some out of vocab tokens by setting + # those token ids to 0. This matches the HF implementation. + per_layer_inputs_mask = torch.logical_and( + input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input) + per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, + torch.zeros_like(input_ids)) + return self.embed_tokens_per_layer( + per_layer_inputs_tokens) * self.embed_scale_per_layer + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + if inputs_embeds is not None: + hidden_states_0 = inputs_embeds + else: + hidden_states_0 = self.get_input_embeddings(input_ids) + + # Per layer inputs. + if input_ids is None: + raise ValueError("Passing None for input ids is not supported.") + per_layer_inputs = self.get_per_layer_input_embeddings(input_ids) + per_layer_inputs = per_layer_inputs.reshape( + -1, self.config.num_hidden_layers, + self.config.hidden_size_per_layer_input) + per_layer_projection = self.per_layer_model_projection(hidden_states_0) + per_layer_projection = per_layer_projection.reshape( + *hidden_states_0.shape[:-1], + self.config.num_hidden_layers, + self.config.hidden_size_per_layer_input, + ) + per_layer_projection = self.per_layer_projection_norm( + per_layer_projection) + per_layer_inputs = per_layer_projection + per_layer_inputs + per_layer_inputs *= self.per_layer_input_scale + + # Altup embed. + hidden_states = [hidden_states_0] * self.config.altup_num_inputs + target_magnitude = torch.mean(hidden_states_0**2, dim=-1, + keepdim=True)**0.5 + for i in range(1, self.config.altup_num_inputs): + hidden_states[i] = self.altup_projections[i - 1](hidden_states[i]) + new_magnitude = torch.mean(hidden_states[i]**2, + dim=-1, + keepdim=True)**0.5 + hidden_states[i] *= target_magnitude / torch.maximum( + new_magnitude, self.eps) + hidden_states = torch.stack(hidden_states, dim=0) + + # Transformer blocks. + for layer_idx, layer in enumerate(self.layers): + # [altup_num_inputs, num_tokens, hidden_size] + hidden_states = layer( + positions=positions, + hidden_states=hidden_states, + per_layer_input=per_layer_inputs[:, layer_idx, :], + **kwargs, + ) + + # Altup unembed. + target_magnitude = torch.mean(hidden_states[0]**2, + dim=-1, + keepdim=True)**0.5 + for i in range(1, self.config.altup_num_inputs): + hidden_states[i] = self.altup_unembed_projections[i - 1]( + hidden_states[i]) + new_magnitude = torch.mean(hidden_states[i]**2, + dim=-1, + keepdim=True)**0.5 + hidden_states[i] *= target_magnitude / torch.maximum( + new_magnitude, self.eps) + # [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size] + hidden_states = torch.mean(hidden_states, dim=0) + + return self.norm(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for (param_name, shard_name, shard_id) in stacked_params_mapping: + if shard_name not in name: + continue + # Avoid spurious match with ".up_proj". + if "altup_projections" in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class Gemma3nModel(nn.Module): + + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.language_model = Gemma3nTextModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "language_model")) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + return self.language_model(input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + **kwargs) + + +class Gemma3nForConditionalGeneration(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + lora_config = vllm_config.lora_config + del lora_config # Unused. + super().__init__() + self.config = config + self.model = Gemma3nModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.logits_processor = LogitsProcessor( + config.text_config.vocab_size, + soft_cap=config.text_config.final_logit_softcapping) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.language_model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds, **kwargs) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: Optional[SamplingMetadata], + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.model.language_model.embed_tokens, + hidden_states, sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, + skip_substrs=([ + "embed_audio.", "embed_vision.", + "audio_tower.", "vision_tower." + ])) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/glm.py b/vllm/model_executor/models/glm.py new file mode 100644 index 0000000..defa77b --- /dev/null +++ b/vllm/model_executor/models/glm.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only HF format GLM-4 model compatible with THUDM weights.""" +from vllm.config import VllmConfig +from vllm.model_executor.models.llama import LlamaForCausalLM + +from .utils import PPMissingLayer + + +class GlmForCausalLM(LlamaForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + vllm_config.model_config.hf_config.partial_rotary_factor = 0.5 + super().__init__(vllm_config=vllm_config, prefix=prefix) + # Hack Llama model to fit HF format GLM implementation + # Attention difference between GLM and Llama: + # 1. Half partial rotary_dim and no Neox style. + # 2. There is no bias for o_proj in attention + for layer in self.model.layers: + if not isinstance(layer, PPMissingLayer): + layer.self_attn.rotary_emb.is_neox_style = False + layer.self_attn.o_proj.bias = None + layer.self_attn.o_proj.skip_bias_add = True diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py new file mode 100644 index 0000000..11a5d0a --- /dev/null +++ b/vllm/model_executor/models/glm4.py @@ -0,0 +1,324 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The Zhipu AI team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GLM-4-0414 model compatible with HuggingFace weights.""" + +import os + +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch import nn +from transformers import Glm4Config + +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .llama import LlamaMLP as Glm4MLP +from .llama import LlamaModel +from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix + +from vllm.utils import W8a8GetCacheJSON +from vllm import _custom_ops as ops +from vllm.model_executor.utils import pad_weight, gemm_bank_conf + +class Glm4Attention(nn.Module): + + def __init__(self, + config: Glm4Config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + head_dim: Optional[int] = None, + qkv_bias: bool = False, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or hidden_size // self.total_num_heads + self.rotary_dim = self.head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.rotary_dim, + max_position=max_position, + base=self.rope_theta, + rope_scaling=rope_scaling, + partial_rotary_factor=partial_rotary_factor, + is_neox_style=False, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=attn_type) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Glm4DecoderLayer(nn.Module): + + def __init__( + self, + config: Glm4Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + + self.self_attn = Glm4Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + qkv_bias=getattr(config, 'attention_bias', False), + head_dim=getattr(config, 'head_dim', None), + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=f"{prefix}.self_attn", + attn_type=AttentionType.DECODER, + ) + self.mlp = Glm4MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_self_attn_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_mlp_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + hidden_states = self.post_self_attn_layernorm(hidden_states) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_mlp_layernorm(hidden_states) + + return hidden_states, residual + + +ALL_DECODER_LAYER_TYPES = { + "attention": Glm4DecoderLayer, +} + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) +class Glm4Model(LlamaModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + layer_type=Glm4DecoderLayer) + + +class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Glm4Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + self.quant_method = None + if quant_config is not None: + self.quant_method=quant_config.get_name() + self.quant_config=quant_config + + self.tritonsingleton= W8a8GetCacheJSON() + self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' + # self.use_lm_nn = os.environ.get('LM_NN') == '1' + self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' + self.use_fa_pad = os.environ.get('FA_PAD') == '1' + self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' + self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py new file mode 100644 index 0000000..a3908e3 --- /dev/null +++ b/vllm/model_executor/models/glm4_1v.py @@ -0,0 +1,1590 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/Glm4v/modeling_Glm4v.py +# Copyright 2025 The vLLM team. +# Copyright 2025 The ZhipuAI Team. +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GLM-4V model compatible with HuggingFace weights.""" + +import math +from collections.abc import Iterable, Mapping, Sequence +from functools import partial +from typing import Any, Callable, Literal, Optional, TypedDict, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers import BatchFeature +from transformers.models.glm4v.configuration_glm4v import (Glm4vConfig, + Glm4vVisionConfig) +from transformers.models.glm4v.image_processing_glm4v import ( + Glm4vImageProcessor, smart_resize) +from transformers.models.glm4v.video_processing_glm4v import ( + Glm4vVideoProcessor) +from transformers.video_utils import VideoMetadata + +from vllm.config import VllmConfig +from vllm.distributed import parallel_state +from vllm.distributed import utils as dist_utils +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, VideoItem) +from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, + MultiModalDataParser) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.platforms import _Backend +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.config import uses_mrope + +from ..layers.activation import SiluAndMul +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .qwen2_vl import _qwen2vl_field_config, apply_rotary_pos_emb_vision +from .utils import (AutoWeightsLoader, WeightsMapper, + init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) +from .vision import get_vit_attn_backend + +logger = init_logger(__name__) + +# For profile run +_MAX_FRAMES_PER_VIDEO = 600 + +# === Vision Inputs === # + + +class Glm4vImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: torch.Tensor + """Shape: + `(num_patches, num_channels * patch_size * patch_size)` + """ + + image_grid_thw: torch.Tensor + """Shape: `(num_images, 3)` + This should be in `(grid_t, grid_h, grid_w)` format. + """ + + +class Glm4vImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + image_embeds: torch.Tensor + """Supported types: + - List[`torch.Tensor`]: A list of tensors holding all images' features. + Each tensor holds an image's features. + - `torch.Tensor`: A tensor holding all images' features + (concatenation of all images' feature tensors). + + Tensor shape: `(num_image_features, hidden_size)` + - `num_image_features` varies based on + the number and resolution of the images. + - `hidden_size` must match the hidden size of language model backbone. + """ + + image_grid_thw: torch.Tensor + """Shape: `(num_images, 3)` + This should be in `(grid_t, grid_h, grid_w)` format. + """ + + +Glm4vImageInputs = Union[Glm4vImagePixelInputs, Glm4vImageEmbeddingInputs] + + +class Glm4vVideoPixelInputs(TypedDict): + type: Literal["pixel_values_videos"] + pixel_values_videos: torch.Tensor + """Shape: + `(num_patches, + num_channels * temporal_patch_size * patch_size * patch_size)` + """ + # video_metadata: Union[list[VideoMetadata], list[dict]] + video_grid_thw: Union[list[torch.Tensor], torch.Tensor] + """Shape: `(num_videos, num_frames, 3)` or `(1, num_frames, 3)` + for single video. + Each entry represents [grid_t, grid_h, grid_w] format where: + - grid_t: Temporal grid size (usually 1 for processed video) + - grid_h: Height grid size + - grid_w: Width grid size + This describes the grid structure of the video patches. + """ + + +class Glm4vVideoEmbeddingInputs(TypedDict): + type: Literal["video_embeds"] + + video_embeds: torch.Tensor + """ + Tensor shape: `(num_video_patches, hidden_size)` + - `num_video_patches`: Total number of video patches across all frames + - `hidden_size`: Must match the hidden size of language model backbone + """ + + video_grid_thw: torch.Tensor + """Shape: `(num_videos, 1, 3)` or `(1, 1, 3)` for single video + Each entry represents [grid_t, grid_h, grid_w] format where: + - grid_t: Temporal grid size (usually 1 for processed video) + - grid_h: Height grid size + - grid_w: Width grid size + This describes the grid structure of the video patches. + """ + + +Glm4vVideoInputs = Union[Glm4vVideoPixelInputs, Glm4vVideoEmbeddingInputs] + +# === Vision Encoder === # + + +class Glm4vVisionMLP(nn.Module): + + def __init__( + self, + in_features: int, + hidden_features: int, + bias: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=in_features, + output_sizes=[hidden_features] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): + """All-gather the input tensor interleavely across model parallel group.""" + import torch.distributed as dist + + gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)] + dist.all_gather( + gathered_tensors, + local_tensor, + group=parallel_state.get_tp_group().device_group, + ) + + gathered_tensors_split = [ + torch.split(tensor, hidden_size // tp_size, -1) + for tensor in gathered_tensors + ] + ordered_tensors = [ + tensor for pair in zip(*gathered_tensors_split) for tensor in pair + ] + result_tensor = torch.cat(ordered_tensors, dim=-1) + return result_tensor + + +class Glm4vVisionAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + # Per attention head and per partition values. + self.tp_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, self.tp_size) + + self.qkv = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + self.proj = RowParallelLinear( + input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + bias=False, + ) + + # Detect attention implementation. + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + if self.attn_backend not in { + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + }: + raise RuntimeError( + f"GLM-4V does not support {self.attn_backend} backend now.") + + def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: + # [s, b, 3 * head * head_dim] + seq_len, bs, _ = qkv.shape + if self.tp_size > 1: + qkv = all_gather_interleave(qkv, self.qkv.hidden_size, + self.tp_size) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] + q, k, v = qkv.chunk(3, dim=2) + + # 3 * [s, b, head * head_dim] + if self.tp_size > 1: + splitter = partial( + dist_utils.split_tensor_along_last_dim, + num_partitions=self.tp_size, + ) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] + + # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] + new_shape = ( + seq_len, + bs, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + q, k, v = (x.view(*new_shape) for x in (q, k, v)) + return q, k, v + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + ) -> torch.Tensor: + # [s, b, c] --> [s, b, head * 3 * head_dim] + x, _ = self.qkv(x) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] + q, k, v = self.split_qkv(x) + batch_size = q.shape[1] + + q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() + for x in (q, k, v)) + if rotary_pos_emb is not None: + q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) + k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + + if self.attn_backend == _Backend.FLASH_ATTN: + # from vllm_flash_attn.flash_attn_interface import ( + # flash_attn_varlen_func) + from flash_attn import flash_attn_varlen_func + + q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) + + output = flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0, + causal=False, + ) + + context_layer = rearrange(output, + "(b s) ... -> b s ...", + b=batch_size) + elif self.attn_backend == _Backend.TORCH_SDPA: + # Execute attention entry by entry for speed & less VRAM. + outputs = [] + for i in range(1, len(cu_seqlens)): + start_idx = cu_seqlens[i - 1] + end_idx = cu_seqlens[i] + q_i = q[:, start_idx:end_idx] + k_i = k[:, start_idx:end_idx] + v_i = v[:, start_idx:end_idx] + q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") + for x in [q_i, k_i, v_i]) + output_i = F.scaled_dot_product_attention(q_i, + k_i, + v_i, + dropout_p=0.0) + output_i = rearrange(output_i, "b h s d -> b s h d ") + outputs.append(output_i) + context_layer = torch.cat(outputs, dim=1) + elif self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + + attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, + kv_seqlen=None, + device=q.device) + + context_layer = xops.memory_efficient_attention_forward( + q, k, v, attn_bias=attn_bias, p=0, scale=None) + + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() + + output, _ = self.proj(context_layer) + return output + + +class Glm4vVisionBlock(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + self.attn = Glm4vVisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + self.mlp = Glm4vVisionMLP( + dim, + mlp_hidden_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + ) -> torch.Tensor: + x = x + self.attn( + self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + + x = x + self.mlp(self.norm2(x)) + return x + + +class Glm4vVisionPatchEmbed(nn.Module): + + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 1, + in_channels: int = 3, + hidden_size: int = 1536, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.hidden_size = hidden_size + + kernel_size = (temporal_patch_size, patch_size, patch_size) + self.proj = nn.Conv3d( + in_channels, + hidden_size, + kernel_size=kernel_size, + stride=kernel_size, + bias=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, + self.patch_size) + x = self.proj(x).view(L, self.hidden_size) + return x + + +class Glm4vPatchMerger(nn.Module): + + def __init__( + self, + d_model: int, + context_dim: int, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + ) -> None: + super().__init__() + self.hidden_size = d_model + self.proj = ColumnParallelLinear(self.hidden_size, + self.hidden_size, + bias=bias, + gather_output=True) + self.post_projection_norm = nn.LayerNorm(self.hidden_size) + self.gate_up_proj = MergedColumnParallelLinear( + input_size=self.hidden_size, + output_sizes=[context_dim] * 2, + bias=bias, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + context_dim, + self.hidden_size, + bias=bias, + quant_config=quant_config, + ) + self.act_fn = SiluAndMul() + self.extra_activation_func = nn.GELU() + + def forward(self, x: torch.Tensor): + x, _ = self.proj(x) + x = self.extra_activation_func(self.post_projection_norm(x)) + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Glm4vVisionEmbeddings(nn.Module): + + def __init__(self, config: Glm4vVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, + self.embed_dim) + self.register_buffer( + "position_ids", + torch.arange(self.num_positions).expand((1, -1)), + persistent=False, + ) + + def forward(self, embeddings, lengths, image_shapes, h_coords, + w_coords) -> torch.Tensor: + pos_embed_weight = self.position_embedding.weight + hidden_size = pos_embed_weight.shape[1] + total_seq = h_coords.shape[0] + device = pos_embed_weight.device + + # Move coordinates to correct device + h_coords, w_coords = h_coords.to(device), w_coords.to(device) + + # Handle empty sequence case + if total_seq == 0: + adapted_pos_embed = torch.empty(0, + hidden_size, + device=device, + dtype=pos_embed_weight.dtype) + else: + # Convert inputs to tensors if needed + if isinstance(lengths, list): + lengths = torch.tensor(lengths, + device=device, + dtype=torch.long) + if not isinstance(image_shapes, torch.Tensor): + image_shapes = torch.tensor(image_shapes, + device=device, + dtype=torch.long) + + # Prepare 2D position embedding + orig_size_sq = pos_embed_weight.shape[0] + orig_size = int(orig_size_sq**0.5) + pos_embed_2d = (pos_embed_weight.view( + orig_size, orig_size, + hidden_size).permute(2, 0, + 1).unsqueeze(0).to(device=device, + dtype=torch.float32)) + + # Calculate target dimensions for each patch + target_h = torch.cat([ + image_shapes[i, 1].repeat(lengths[i]) + for i in range(len(lengths)) + ]).to(device=device, dtype=torch.float32) + target_w = torch.cat([ + image_shapes[i, 2].repeat(lengths[i]) + for i in range(len(lengths)) + ]).to(device=device, dtype=torch.float32) + + # Normalize coordinates to [-1, 1] range for grid_sample + h_coords = h_coords.to(device=device, dtype=torch.float32) + w_coords = w_coords.to(device=device, dtype=torch.float32) + norm_w = ((w_coords + 0.5) / target_w) * 2 - 1 + norm_h = ((h_coords + 0.5) / target_h) * 2 - 1 + + # Create sampling grid + grid = (torch.stack((norm_w, norm_h), + dim=-1).unsqueeze(0).unsqueeze(2)) + + # Perform bicubic interpolation + interpolated_embed_fp32 = F.grid_sample( + pos_embed_2d, + grid, + mode="bicubic", + align_corners=False, + padding_mode="border", + ) + + # Reshape and convert back to original dtype + adapted_pos_embed_fp32 = ( + interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)) + adapted_pos_embed = adapted_pos_embed_fp32.to( + pos_embed_weight.dtype).to(embeddings.device) + + # Add adapted position encoding to embeddings + embeddings = embeddings + adapted_pos_embed + return embeddings + + +class Glm4vVisionRotaryEmbedding(nn.Module): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.theta = theta + inv_freq = 1.0 / (theta + **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._freqs_cached = None + + def update_freqs_cache(self, seqlen: int) -> None: + if seqlen > self._seq_len_cached: + seqlen *= 2 + self._seq_len_cached = seqlen + self.inv_freq = 1.0 / (self.theta**(torch.arange( + 0, + self.dim, + 2, + dtype=torch.float, + device=self.inv_freq.device, + ) / self.dim)) + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + self._freqs_cached = freqs + + def forward(self, seqlen: int) -> torch.Tensor: + self.update_freqs_cache(seqlen) + return self._freqs_cached[:seqlen] + + +class Glm4vVisionTransformer(nn.Module): + + def __init__( + self, + vision_config: Glm4vVisionConfig, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + patch_size = vision_config.patch_size + temporal_patch_size = vision_config.temporal_patch_size + in_channels = vision_config.in_channels + depth = vision_config.depth + self.hidden_size = vision_config.hidden_size + self.num_heads = vision_config.num_heads + + self.patch_size = vision_config.patch_size + self.spatial_merge_size = vision_config.spatial_merge_size + self.out_hidden_size = vision_config.out_hidden_size + + self.patch_embed = Glm4vVisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + hidden_size=self.hidden_size, + ) + + norm_layer = partial(RMSNorm, eps=norm_eps) + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2) + self.blocks = nn.ModuleList([ + Glm4vVisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.out_hidden_size, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + ) for layer_idx in range(depth) + ]) + self.merger = Glm4vPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=vision_config.intermediate_size, + quant_config=quant_config, + bias=False, + ) + self.embeddings = Glm4vVisionEmbeddings(vision_config) + + self.post_conv_layernorm = RMSNorm(vision_config.hidden_size, + eps=vision_config.rms_norm_eps) + self.downsample = nn.Conv2d( + in_channels=vision_config.hidden_size, + out_channels=vision_config.out_hidden_size, + kernel_size=vision_config.spatial_merge_size, + stride=vision_config.spatial_merge_size, + ) + self.post_layernorm = RMSNorm(vision_config.hidden_size, + eps=vision_config.rms_norm_eps) + + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + hpos_ids = (hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten()) + wpos_ids = (wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten()) + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb, pos_ids + + def compute_attn_mask_seqlen( + self, + cu_seqlens: torch.Tensor, + ) -> tuple[Optional[int], Optional[list[int]]]: + max_seqlen, seqlens = None, None + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + if self.attn_backend == _Backend.FLASH_ATTN: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + return max_seqlen, seqlens + + def forward( + self, + x: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: + # patchify + x = x.to(device=self.device, dtype=self.dtype) + x = self.patch_embed(x) + x = self.post_conv_layernorm(x) + + # compute position embedding + rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw) + # compute cu_seqlens + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) + + # pre-compute seqlens for attn mask to reduce cuMemcpy operations + max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + x = self.embeddings(x, seqlens, grid_thw, image_type_ids[:, 0], + image_type_ids[:, 1]) + + # transformers + x = x.unsqueeze(1) + for blk in self.blocks: + x = blk( + x, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + + # adapter + x = self.post_layernorm(x) + + x = x.view(-1, self.spatial_merge_size, self.spatial_merge_size, + x.shape[-1]) + x = x.permute(0, 3, 1, 2) + x = self.downsample(x).view(-1, self.out_hidden_size) + x = self.merger(x) + + return x + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("attn.qkv.", "attn.q.", "q"), + ("attn.qkv.", "attn.k.", "k"), + ("attn.qkv.", "attn.v.", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Glm4vProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Glm4vConfig) + + def get_tokenizer(self): + return self.ctx.tokenizer + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": 1} + + def get_image_processor(self) -> Glm4vImageProcessor: + return self.get_hf_processor().image_processor + + def get_video_processor(self) -> Glm4vVideoProcessor: + return self.get_hf_processor().video_processor + + def _get_vision_info( + self, + *, + image_width: int, + image_height: int, + num_frames: int = 16, + do_resize: bool = True, + max_image_pixels: int = 28 * 28 * 2 * 30000, + ) -> tuple[ImageSize, int]: + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + patch_size = vision_config.patch_size + merge_size = vision_config.spatial_merge_size + temporal_patch_size = vision_config.temporal_patch_size + if do_resize: + resized_height, resized_width = smart_resize( + num_frames=num_frames + if num_frames > temporal_patch_size else temporal_patch_size, + height=image_height, + width=image_width, + factor=patch_size * merge_size, + max_pixels=max_image_pixels, + ) + preprocessed_size = ImageSize(width=resized_width, + height=resized_height) + else: + preprocessed_size = ImageSize(width=image_width, + height=image_height) + + # NOTE: Frames are padded to be divisible by `temporal_patch_size` + # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294 + padded_num_frames = num_frames + num_frames % temporal_patch_size + + grid_t = max(padded_num_frames // temporal_patch_size, 1) + grid_h = preprocessed_size.height // patch_size + grid_w = preprocessed_size.width // patch_size + + num_patches = grid_t * grid_h * grid_w + num_vision_tokens = num_patches // (merge_size**2) + + return preprocessed_size, num_vision_tokens + + def get_image_size_with_most_features(self) -> ImageSize: + max_image_size, _ = self._get_vision_info(image_width=9999999, + image_height=9999999) + return max_image_size + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + _, num_image_tokens = self._get_vision_info( + image_width=image_width, + image_height=image_height, + max_image_pixels=28 * 28 * 2 * 6144, + ) + return num_image_tokens + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + ) + + def get_num_video_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int, + ) -> int: + _, num_video_tokens = self._get_vision_info( + image_width=image_width, + image_height=image_height, + num_frames=num_frames, + max_image_pixels=28 * 28 * 2 * 30000, + ) + return num_video_tokens + + def _get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + num_frames = 0 + + while True: + next_num_frames = num_frames + 1 + next_max_tokens = self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=next_num_frames, + ) + if next_max_tokens > max_tokens or next_max_tokens == 0: + break + + num_frames = next_num_frames + + return num_frames + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) + + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = self._get_max_video_frames(seq_len - + max_image_tokens) + max_frames_per_video = min(max_total_frames // max(max_videos, 1), + _MAX_FRAMES_PER_VIDEO) + + return max(max_frames_per_video, 1) + + def _get_video_second_idx(self, metadata: dict[str, Any], + total_frames: int) -> list[int]: + video_processor = self.get_video_processor() + + video_fps = metadata.get("fps", 2.0) + meta_frames = metadata.get("total_num_frames", total_frames) + max_frame_idx = meta_frames - 1 + duration = metadata.get("duration", + round(max_frame_idx / video_fps) + 1) + if duration <= video_processor.max_duration: + n = int(math.floor(duration * video_processor.fps)) + frame_indices = [ + min( + max_frame_idx, + int(math.ceil(i * video_fps / video_processor.fps)), + ) for i in range(n) + ] + else: + num_samples = int(video_processor.max_duration * + video_processor.fps) + if num_samples >= meta_frames: + frame_indices = list(range(meta_frames)) + else: + target_seconds = np.linspace(0, + duration, + num_samples, + endpoint=True) + frame_indices = [ + min(max_frame_idx, int(math.ceil(t * video_fps))) + for t in target_seconds + ] + + seen, uniq = set(), [] + for idx in frame_indices: + if idx not in seen: + seen.add(idx) + uniq.append(idx) + if len(uniq) & 1: + uniq.append(uniq[-1]) + frame_indices = uniq + + full_second_idxs = [int(idx / video_fps) for idx in frame_indices] + timestamps_list = full_second_idxs[::2] + selected_timestamps = [] + for idx in range(0, len(timestamps_list)): + selected_timestamps.append(timestamps_list[idx]) + return selected_timestamps + + +class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + hf_config = self.info.get_hf_config() + hf_processor = self.info.get_hf_processor() + tokenizer = self.info.get_tokenizer() + + image_token: str = hf_processor.image_token + video_token_ids = [ + hf_config.video_start_token_id, + hf_processor.video_token_id, + hf_config.video_end_token_id, + ] + video_token = tokenizer.decode(video_token_ids) + + return image_token * num_images + video_token * num_videos + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + target_width, target_height = ( + self.info.get_image_size_with_most_features()) + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts) + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos, + ), + } + + def _get_dummy_videos( + self, + *, + width: int, + height: int, + num_frames: int, + num_videos: int, + ) -> list[VideoItem]: + video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8) + video_items = [] + for i in range(num_videos): + video_metadata = { + "fps": 2.0, + "duration": num_frames / 2.0, + "total_num_frames": num_frames, + "video_backend": "opencv", + } + video_item = (video.copy(), video_metadata) + video_items.append(video_item) + + return video_items + + +class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): + + def _get_data_parser(self) -> MultiModalDataParser: + return MultiModalDataParser(video_needs_metadata=True) + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + mm_data = dict(mm_data) + processor = self.info.get_hf_processor(**mm_kwargs) + + # GLM-4.1V use `image_token_id` as video placeholder, we need to + # replace it with `video_token_id` for video processing. So we + # separate video processing from image processing. + if ("videos" in mm_data and isinstance(mm_data["videos"], list) + and len(mm_data["videos"]) > 0): + video_grid_thw_lst = [] + pixel_values_videos_lst = [] + for item in mm_data.pop("videos", []): + video_array, metadata = item + + # FIXME(Isotr0py): Activate the below logic after we can disable + # resampling from video loader backend. + # assert metadata["total_num_frames"] == len(video_array), ( + # f"Total frames {metadata['total_num_frames']} does not " + # f"match the length of video array {len(video_array)}.") + + # NOTE: Temporary workaround for resampled videos. + # this can cause a divergence with HF implementation if + # the input video is resampled in advance. + + if metadata["total_num_frames"] != len(video_array): + logger.warning( + "Total frames in metadata " + "(%s) does not match the length of " + "video array %s. This can " + "be because the video is resampled " + "in advance. This may cause " + "a divergence with HF implementation.", + metadata["total_num_frames"], + len(video_array), + ) + metadata["total_num_frames"] = len(video_array) + metadata = VideoMetadata(**metadata) + + video_mm_data = dict() + video_mm_data["videos"] = [[video_array]] + video_mm_data["video_metadata"] = [[metadata]] + + video_outputs = super()._call_hf_processor( + prompt="<|begin_of_video|><|video|><|end_of_video|>", + mm_data=video_mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + input_ids = video_outputs.pop("input_ids") + input_ids[input_ids == processor.image_token_id] = ( + processor.video_token_id) + video_placeholder = processor.tokenizer.batch_decode( + input_ids)[0] + prompt = prompt.replace( + "<|begin_of_video|><|video|><|end_of_video|>", + video_placeholder, + ) + + grid_t = len(video_outputs["video_grid_thw"]) + _, grid_h, grid_w = video_outputs["video_grid_thw"][0] + grid_thw = torch.tensor([[grid_t, grid_h, grid_w]]) + + video_grid_thw_lst.append(grid_thw) + pixel_values_videos_lst.append( + video_outputs["pixel_values_videos"]) + video_outputs = dict( + pixel_values_videos=torch.cat(pixel_values_videos_lst), + video_grid_thw=torch.cat(video_grid_thw_lst), + ) + else: + video_outputs = dict() + + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + combined_outputs = dict( + processed_outputs, + **video_outputs, + ) + return BatchFeature(combined_outputs) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _qwen2vl_field_config(hf_inputs) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor( + **hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + hf_config = self.info.get_hf_config() + + boi_token_id = hf_config.image_start_token_id + eoi_token_id = hf_config.image_end_token_id + + bov_token_id = hf_config.video_start_token_id + eov_token_id = hf_config.video_end_token_id + + merge_length = image_processor.merge_size**2 + + def get_image_replacement_glm4v(item_idx: int): + grid_thw = out_mm_kwargs["image_grid_thw"][item_idx] + assert isinstance(grid_thw, torch.Tensor) + + num_tokens = int(grid_thw.prod()) // merge_length + return [hf_processor.image_token_id] * num_tokens + + def get_video_replacement_glm4v(item_idx: int): + grid_thw = out_mm_kwargs["video_grid_thw"][item_idx] + assert isinstance(grid_thw, torch.Tensor) + + video, metadata = mm_items["video"][item_idx] + timestamps = self.info._get_video_second_idx(metadata, len(video)) + frames_idx_token = [ + tokenizer.encode(str(i), add_special_tokens=False) + for i in timestamps + ] + num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length + placeholder = [] + placeholder.append(bov_token_id) + for frame_idx in frames_idx_token: + placeholder.append(boi_token_id) + placeholder.extend([hf_processor.video_token_id] * + num_tokens_per_frame) + placeholder.append(eoi_token_id) + placeholder.extend(frame_idx) + placeholder.append(eov_token_id) + return placeholder + + return [ + PromptReplacement( + modality="image", + target=hf_processor.image_token, + replacement=get_image_replacement_glm4v, + ), + PromptReplacement( + modality="video", + target="<|begin_of_video|><|video|><|end_of_video|>", + replacement=get_video_replacement_glm4v, + ), + ] + + +@MULTIMODAL_REGISTRY.register_processor( + Glm4vMultiModalProcessor, + info=Glm4vProcessingInfo, + dummy_inputs=Glm4vDummyInputsBuilder, +) +class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # To ensure correct weight loading and mapping. + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "lm_head.": "language_model.lm_head.", + "model.language_model.": "language_model.model.", + "model.visual.": "visual.", + }) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<|begin_of_image|><|image|><|end_of_image|>" + if modality.startswith("video"): + return "<|begin_of_video|><|video|><|end_of_video|>" + + raise ValueError("Only image or video modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: Glm4vConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.visual = Glm4vVisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-5), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + ) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, ""), + architectures=["Glm4ForCausalLM"], + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError( + f"Incorrect type of {name}. Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") + return torch.concat(list(mm_input)) + else: + return torch.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Glm4vImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}") + + return Glm4vImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + if image_embeds is not None: + image_embeds = self._validate_and_reshape_mm_tensor( + image_embeds, "image embeds") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + return Glm4vImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw, + ) + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[Glm4vVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + if pixel_values_videos is None and video_embeds is None: + return None + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + return Glm4vVideoPixelInputs( + type="pixel_values_videos", + # video_metadata=video_metadata, + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + if video_embeds is not None: + video_embeds = self._validate_and_reshape_mm_tensor( + video_embeds, "video embeds") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + if not isinstance(video_embeds, torch.Tensor): + raise ValueError("Incorrect type of video embeddings. " + f"Got type: {type(video_embeds)}") + return Glm4vVideoEmbeddingInputs( + type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw, + ) + + def _process_image_input( + self, image_input: Glm4vImageInputs) -> tuple[torch.Tensor, ...]: + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, video_input: Glm4vVideoInputs) -> tuple[torch.Tensor, ...]: + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + device = self.visual.device + flat_grid_thw = torch.cat([ + torch.tensor([[1, h, w]] * t, device=device) + for t, h, w in grid_thw + ]) + if video_input["type"] == "video_embeds": + video_embeds = video_input["video_embeds"].type(self.visual.dtype) + else: + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, + grid_thw=flat_grid_thw) + + # Split concatenated embeddings for each video item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return video_embeds.split(sizes.tolist()) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + mm_input_by_modality = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if (input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality): + mm_input_by_modality["image"] = ( + self._parse_and_validate_image_input(**kwargs)) + if (input_key in ("pixel_values_videos", "video_embeds") + and "video" not in mm_input_by_modality): + mm_input_by_modality["video"] = ( + self._parse_and_validate_video_input(**kwargs)) + return mm_input_by_modality + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs( + **kwargs) + if not mm_input_by_modality: + return None + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in mm_input_by_modality: + multimodal_input = mm_input_by_modality[modality] + if modality == "image": + vision_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings += vision_embeddings + if modality == "video": + video_embeddings = self._process_video_input(multimodal_input) + multimodal_embeddings += video_embeddings + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if (multimodal_embeddings is not None + and len(multimodal_embeddings) != 0 + and all(embed.numel() > 0 for embed in multimodal_embeddings)): + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + [self.config.image_token_id, self.config.video_token_id], + ) + return inputs_embeds + + def get_input_embeddings_v0( + self, + input_ids: torch.Tensor, + image_input: Optional[Glm4vImageInputs] = None, + video_input: Optional[Glm4vVideoInputs] = None, + ) -> torch.Tensor: + inputs_embeds = self.get_input_embeddings(input_ids) + if image_input is not None: + image_embeds = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=self.config.image_token_id, + ) + + if video_input is not None: + video_embeds = self._process_video_input(video_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + video_embeds, + placeholder_token_id=self.config.video_token_id, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + """Run forward pass for GLM-4V. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + positions: Flattened (concatenated) position ids corresponding to a + batch. + **NOTE**: If mrope is enabled (default setting for GLM-4V + opensource models), the shape will be `(3, seq_len)`, + otherwise it will be `(seq_len,). + pixel_values: Pixel values to be fed to a model. + `None` if no images are passed. + image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. + `None` if no images are passed. + pixel_values_videos: Pixel values of videos to be fed to a model. + `None` if no videos are passed. + video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. + `None` if no videos are passed. + second_per_grid_ts: Tensor `(num_videos)` of video time interval ( + in seconds) for each grid along the temporal dimension in the + 3D position IDs. `None` if no videos are passed. + """ + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. + elif inputs_embeds is None: + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) + + if image_input is None and video_input is None: + inputs_embeds = None + else: + if uses_mrope(self.config): + assert positions.ndim == 2 and positions.size(0) == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.size()}") + inputs_embeds = self.get_input_embeddings_v0( + input_ids, + image_input=image_input, + video_input=video_input) + input_ids = None + + hidden_states = self.language_model.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="visual.merger.", + tower_model="visual.", + ) diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py new file mode 100644 index 0000000..bf6696b --- /dev/null +++ b/vllm/model_executor/models/glm4_moe.py @@ -0,0 +1,685 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The ZhipuAI Team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GLM-4.5 model compatible with HuggingFace weights.""" +import typing +from collections.abc import Callable, Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import (get_ep_group, get_pp_group, + get_tensor_model_parallel_world_size) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class Glm4MoeMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Glm4MoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts: int = config.n_routed_experts + self.n_shared_experts: int = config.n_shared_experts + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + + self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + + # noaux_tc is not set in transformers new config now + self.gate.e_score_correction_bias = (nn.Parameter( + torch.empty(config.n_routed_experts))) + + # Load balancing settings. + vllm_config = get_current_vllm_config() + parallel_config = vllm_config.parallel_config + self.enable_eplb = enable_eplb + + self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_logical_experts = self.n_routed_experts + self.n_physical_experts = (self.n_logical_experts + + self.n_redundant_experts) + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = (self.ep_rank * + self.n_local_physical_experts) + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func="sigmoid", + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) + + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = Glm4MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=self.experts.must_reduce_shared_expert_outputs( + ), + prefix=f"{prefix}.shared_experts", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits) * self.routed_scaling_factor + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + final_hidden_states = ( + self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states)) + return final_hidden_states.view(num_tokens, hidden_dim) + + +class Glm4MoeAttention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 131072, + head_dim: Optional[int] = None, + rms_norm_eps: float = 1e-05, + qkv_bias: bool = False, + use_qk_norm: bool = False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.use_qk_norm = use_qk_norm + + self.qkv_proj = QKVParallelLinear(hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + partial_rotary_factor=partial_rotary_factor, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + if self.use_qk_norm: + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.use_qk_norm: + q = self.q_norm(q.reshape(-1, self.num_heads, + self.head_dim)).reshape(q.shape) + k = self.k_norm(k.reshape(-1, self.num_kv_heads, + self.head_dim)).reshape(k.shape) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Glm4MoeDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 131072) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. + layer_idx = int(prefix.split(sep='.')[-1]) + self.layer_idx = layer_idx + + self.self_attn = Glm4MoeAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + head_dim=config.head_dim, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=config.attention_bias, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + use_qk_norm=config.use_qk_norm, + ) + + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace): + self.mlp = Glm4MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + enable_eplb=enable_eplb, + ) + else: + self.mlp = Glm4MoeMLP(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.routed_scaling_factor = config.routed_scaling_factor + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class Glm4MoeModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + enable_eplb = vllm_config.parallel_config.enable_eplb + self.config = config + + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Glm4MoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + enable_eplb=enable_eplb, + ), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. + weight_loader = typing.cast(Callable[..., bool], + param.weight_loader) + success = weight_loader(param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True) + if success: + name = name_mapped + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class Glm4MoeForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Glm4MoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + self.expert_weights = [] + + # Set MoE hyperparameters + self.num_moe_layers = (config.num_hidden_layers - + config.first_k_dense_replace) + self.num_expert_groups = config.n_group + + self.moe_layers: list[FusedMoE] = [] + for layer in self.model.layers: + assert isinstance(layer, Glm4MoeDecoderLayer) + if isinstance(layer.mlp, Glm4MoE): + self.moe_layers.append(layer.mlp.experts) + + # Pick last one layer since the first ones may be dense layers. + example_moe = typing.cast( + Glm4MoE, self.model.layers[config.num_hidden_layers - 1].mlp) + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. + self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + +def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, + weight_name: str) -> Optional[int]: + if hasattr(config, + "num_nextn_predict_layers") and (config.num_nextn_predict_layers + > 0): + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + if f"layers.{layer_idx+i}." in weight_name: + return layer_idx + i + return None \ No newline at end of file diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py new file mode 100644 index 0000000..5d1b23a --- /dev/null +++ b/vllm/model_executor/models/glm4_moe_mtp.py @@ -0,0 +1,307 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The ZhipuAI Team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GLM-4.5 MTP model compatible with HuggingFace weights.""" + +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import CacheConfig, VllmConfig +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .glm4_moe import Glm4MoeDecoderLayer, get_spec_layer_idx_from_weight_name +from .interfaces import SupportsPP +from .utils import maybe_prefix + + +class SharedHead(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(hidden_states) + + +class Glm4MoeMultiTokenPredictorLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.eh_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + self.shared_head = SharedHead(config=config, quant_config=quant_config) + self.mtp_block = Glm4MoeDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_index: int = 0, + ) -> torch.Tensor: + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds[positions == 0] = 0 + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states = self.eh_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states, residual = self.mtp_block(positions=positions, + hidden_states=hidden_states, + residual=None) + hidden_states = residual + hidden_states + return hidden_states + + +class Glm4MoeMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict({ + str(idx): + Glm4MoeMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + }) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + current_step_idx = (spec_step_idx % self.num_mtp_layers) + return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( + input_ids, + positions, + previous_hidden_states, + inputs_embeds, + current_step_idx, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> torch.Tensor: + current_step_idx = (spec_step_idx % self.num_mtp_layers) + mtp_layer = self.layers[str(self.mtp_start_layer_idx + + current_step_idx)] + logits = self.logits_processor(mtp_layer.shared_head.head, + mtp_layer.shared_head(hidden_states), + sampling_metadata) + return logits + + +class Glm4MoeMTP(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + self.model = Glm4MoeMultiTokenPredictor(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, + previous_hidden_states, inputs_embeds, + spec_step_idx) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> Optional[torch.Tensor]: + return self.model.compute_logits(hidden_states, sampling_metadata, + spec_step_idx) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is None: + continue + name = self._rewrite_spec_layer_name(spec_layer, name) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # According to DeepSeek-V3 Technical Report, MTP modules + # shares embedding layer. We only load the first weights. + if (spec_layer != self.model.mtp_start_layer_idx + and ".layers" not in name): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: + """ + Rewrite the weight name to match the format of the original model. + Add .mtp_block for modules in transformer layer block for spec layer + and rename shared layer weights to be top level. + """ + spec_layer_weight_names = [ + "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" + ] + shared_weight_names = ["embed_tokens"] + spec_layer_weight = False + shared_weight = False + for weight_name in spec_layer_weight_names: + if weight_name in name: + spec_layer_weight = True + if weight_name in shared_weight_names: + shared_weight = True + break + if not spec_layer_weight: + # treat rest weights as weights for transformer layer block + name = name.replace(f"model.layers.{spec_layer}.", + f"model.layers.{spec_layer}.mtp_block.") + elif shared_weight: + # treat shared weights as top level weights + name = name.replace(f"model.layers.{spec_layer}.", "model.") + return name \ No newline at end of file diff --git a/vllm/model_executor/models/glm4_vision_encoder.py b/vllm/model_executor/models/glm4_vision_encoder.py new file mode 100644 index 0000000..864aba1 --- /dev/null +++ b/vllm/model_executor/models/glm4_vision_encoder.py @@ -0,0 +1,312 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/THUDM/GLM-4 +"""Inference-only GLM-4v model visual encoder compatible with THUDM weights.""" +from argparse import Namespace +from typing import Optional + +import torch +from torch import nn +from torch.nn import LayerNorm + +from vllm.attention.layer import MultiHeadAttention +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + + +class PatchEmbedding(nn.Module): + + def __init__(self, config): + super().__init__() + self.proj = nn.Conv2d(config.in_channels, + config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size) + self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.position_embedding = nn.Embedding(config.num_positions, + config.hidden_size) + + def forward(self, images: torch.Tensor) -> torch.Tensor: + """ + Parameters: + images : torch.Tensor + Input image tensor with shape (B, C, H, W) + + Returns: + torch.Tensor + Transformed tensor with shape (B, L, D) + """ + images = images.to(device=self.proj.weight.device, + dtype=self.proj.weight.dtype) + x = self.proj(images) + x = x.flatten(2).transpose(1, 2) + cls_token = self.cls_embedding.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + x += self.position_embedding.weight.unsqueeze(0) + return x + + +class Attention(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + super().__init__() + self.hidden_size = config.hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_rank = config.num_heads // self.tp_size + self.head_dim = config.hidden_size // config.num_heads + self.scale = self.head_dim**-0.5 + + self.query_key_value = QKVParallelLinear( + config.hidden_size, + self.head_dim, + config.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.query_key_value", + ) + self.dense = RowParallelLinear( + config.hidden_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + + self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim, + self.scale) + self.output_dropout = torch.nn.Dropout(config.dropout_prob) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + qkv, _ = self.query_key_value(x) # B, L, 3 * H * D + q, k, v = qkv.chunk(3, dim=-1) + + out = self.attn(q, k, v) + output, _ = self.dense(out) + output = self.output_dropout(output) + return output + + +class MLP(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.fc1(x) + x = self.activation_fn(x) + x, _ = self.fc2(x) + return x + + +class TransformerLayer(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + super().__init__() + self.input_layernorm = LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.attention = Attention(config, + quant_config=quant_config, + prefix=f"{prefix}.attention") + self.mlp = MLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.post_attention_layernorm = LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states): + attention_input = hidden_states + attention_output = self.input_layernorm( + self.attention(attention_input)) + hidden_states = attention_input + attention_output + mlp_input = hidden_states + mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)) + output = mlp_input + mlp_output + return output + + +class Transformer(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + super().__init__() + self.layers = nn.ModuleList([ + TransformerLayer(config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(config.num_hidden_layers) + ]) + + def forward(self, hidden_states): + for layer_module in self.layers: + hidden_states = layer_module(hidden_states) + return hidden_states + + +class GLU(nn.Module): + + def __init__( + self, + config, + in_features, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + """ + The original implementation is the same as: + ```python + self.dense_h_to_4h = ColumnParallelLinear( + config.hidden_size, + config.ffn_hidden_size, + bias=False, + quant_config=quant_config + ) + + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + config.ffn_hidden_size, + bias=False, + quant_config=quant_config + ) + ``` + ``` + gate_proj_output, _ = self.gate_proj(x) + dense_h_to_4h_output, _ = self.dense_h_to_4h(x) + x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1) + ``` + + We merge two ColumnParallelLinear into one MergedColumnParallelLinear: + ``` + self.merged_proj = MergedColumnParallelLinear( + config.hidden_size, + [config.ffn_hidden_size] * 2, + bias=False, + quant_config=quant_config + ) + ``` + ``` + x, _ = self.merged_proj(x) + ``` + """ + super().__init__() + self.linear_proj = ReplicatedLinear(in_features, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.linear_proj") + self.norm1 = nn.LayerNorm(config.hidden_size) + self.act1 = nn.GELU() + self.act2 = SiluAndMul() + + self.merged_proj = MergedColumnParallelLinear( + config.hidden_size, [config.ffn_hidden_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.merged_proj") + + self.dense_4h_to_h = RowParallelLinear( + config.ffn_hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.dense_4h_to_h") + + def forward(self, x): + x, _ = self.linear_proj(x) + x = self.act1(self.norm1(x)) + x, _ = self.merged_proj(x) + x = self.act2(x) + x, _ = self.dense_4h_to_h(x) + return x + + +class EVA2CLIPModel(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + super().__init__() + vision_config = Namespace(**config.vision_config) + self.patch_embedding = PatchEmbedding(vision_config) + self.transformer = Transformer(vision_config, + quant_config=quant_config, + prefix=f"{prefix}.transformer") + self.linear_proj = GLU(config, + in_features=config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.linear_proj") + self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, + out_channels=config.hidden_size, + kernel_size=2, + stride=2) + self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.scaling_factor = vision_config.scaling_factor + + def forward(self, images: torch.Tensor) -> torch.Tensor: + """ + Parameters: + images : torch.Tensor + Input image tensor with shape (B, C, H, W) + + Returns: + torch.Tensor + Transformed tensor with shape (B, L, D) + """ + x = self.patch_embedding(images) + x = self.transformer(x) + x = x[:, 1:] + + b, s, h = x.shape + grid_size = int(s**0.5) + x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2) + x = self.conv(x) + + x = x.flatten(2).transpose(1, 2) + x = self.linear_proj(x) + boi = self.boi.expand(x.shape[0], -1, -1) + eoi = self.eoi.expand(x.shape[0], -1, -1) + x = torch.cat((boi, x, eoi), dim=1) + x = x / self.scaling_factor + return x \ No newline at end of file diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py new file mode 100644 index 0000000..7584b51 --- /dev/null +++ b/vllm/model_executor/models/glm4v.py @@ -0,0 +1,657 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/THUDM/CogAgent +"""Inference-only CogAgent model compatible with THUDM weights.""" +from argparse import Namespace +from collections.abc import Mapping, Sequence +from typing import Literal, Optional, TypedDict, Union + +import torch +from torch import nn +from torch.nn import LayerNorm +from torchvision import transforms +from torchvision.transforms import InterpolationMode +from transformers import BatchFeature, PreTrainedTokenizer, TensorType +from transformers.image_utils import ImageInput +from transformers.tokenization_utils_base import TextInput + +from vllm.attention.layer import MultiHeadAttention +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import ChatGLMConfig + +from .chatglm import ChatGLMBaseModel, ChatGLMModel +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .utils import flatten_bn, merge_multimodal_embeddings + + +class GLMVImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size, num_channels, height, width)`""" + + +class EVA2CLIPPatchEmbedding(nn.Module): + + def __init__(self, config): + super().__init__() + self.proj = nn.Conv2d(config.in_channels, + config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size) + self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.position_embedding = nn.Embedding(config.num_positions, + config.hidden_size) + + def forward(self, images: torch.Tensor) -> torch.Tensor: + """ + Parameters: + images : torch.Tensor + Input image tensor with shape (B, C, H, W) + + Returns: + torch.Tensor + Transformed tensor with shape (B, L, D) + """ + images = images.to(device=self.proj.weight.device, + dtype=self.proj.weight.dtype) + x = self.proj(images) + x = x.flatten(2).transpose(1, 2) + cls_token = self.cls_embedding.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + x += self.position_embedding.weight.unsqueeze(0) + return x + + +class EVA2CLIPAttention(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + super().__init__() + self.hidden_size = config.hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_rank = config.num_heads // self.tp_size + self.head_dim = config.hidden_size // config.num_heads + self.scale = self.head_dim**-0.5 + + self.query_key_value = QKVParallelLinear( + config.hidden_size, + self.head_dim, + config.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.query_key_value", + ) + self.dense = RowParallelLinear( + config.hidden_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + + self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim, + self.scale) + self.output_dropout = torch.nn.Dropout(config.dropout_prob) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + qkv, _ = self.query_key_value(x) # B, L, 3 * H * D + q, k, v = qkv.chunk(3, dim=-1) + + out = self.attn(q, k, v) + output, _ = self.dense(out) + output = self.output_dropout(output) + return output + + +class EVA2CLIPMLP(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.fc1(x) + x = self.activation_fn(x) + x, _ = self.fc2(x) + return x + + +class EVA2CLIPTransformerLayer(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + super().__init__() + self.input_layernorm = LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.attention = EVA2CLIPAttention(config, + quant_config=quant_config, + prefix=f"{prefix}.attention") + self.mlp = EVA2CLIPMLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.post_attention_layernorm = LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states): + attention_input = hidden_states + attention_output = self.input_layernorm( + self.attention(attention_input)) + hidden_states = attention_input + attention_output + mlp_input = hidden_states + mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)) + output = mlp_input + mlp_output + return output + + +class EVA2CLIPTransformer(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + super().__init__() + self.layers = nn.ModuleList([ + EVA2CLIPTransformerLayer(config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(config.num_hidden_layers) + ]) + + def forward(self, hidden_states): + for layer_module in self.layers: + hidden_states = layer_module(hidden_states) + return hidden_states + + +class EVA2CLIPGLU(nn.Module): + + def __init__( + self, + config, + in_features, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + """ + The original implementation is the same as: + ```python + self.dense_h_to_4h = ColumnParallelLinear( + config.hidden_size, + config.ffn_hidden_size, + bias=False, + quant_config=quant_config + ) + + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + config.ffn_hidden_size, + bias=False, + quant_config=quant_config + ) + ``` + ``` + gate_proj_output, _ = self.gate_proj(x) + dense_h_to_4h_output, _ = self.dense_h_to_4h(x) + x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1) + ``` + + We merge two ColumnParallelLinear into one MergedColumnParallelLinear: + ``` + self.merged_proj = MergedColumnParallelLinear( + config.hidden_size, + [config.ffn_hidden_size] * 2, + bias=False, + quant_config=quant_config + ) + ``` + ``` + x, _ = self.merged_proj(x) + ``` + """ + super().__init__() + self.linear_proj = ReplicatedLinear(in_features, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.linear_proj") + self.norm1 = nn.LayerNorm(config.hidden_size) + self.act1 = nn.GELU() + self.act2 = SiluAndMul() + + self.merged_proj = MergedColumnParallelLinear( + config.hidden_size, [config.ffn_hidden_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.merged_proj") + + self.dense_4h_to_h = RowParallelLinear( + config.ffn_hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.dense_4h_to_h") + + def forward(self, x): + x, _ = self.linear_proj(x) + x = self.act1(self.norm1(x)) + x, _ = self.merged_proj(x) + x = self.act2(x) + x, _ = self.dense_4h_to_h(x) + return x + + +class EVA2CLIPModel(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', + ): + super().__init__() + vision_config = Namespace(**config.vision_config) + self.patch_embedding = EVA2CLIPPatchEmbedding(vision_config) + self.transformer = EVA2CLIPTransformer(vision_config, + quant_config=quant_config, + prefix=f"{prefix}.transformer") + self.linear_proj = EVA2CLIPGLU(config, + in_features=config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.linear_proj") + self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, + out_channels=config.hidden_size, + kernel_size=2, + stride=2) + self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.scaling_factor = vision_config.scaling_factor + + def forward(self, images: torch.Tensor) -> torch.Tensor: + """ + Parameters: + images : torch.Tensor + Input image tensor with shape (B, C, H, W) + + Returns: + torch.Tensor + Transformed tensor with shape (B, L, D) + """ + x = self.patch_embedding(images) + x = self.transformer(x) + x = x[:, 1:] + + b, s, h = x.shape + grid_size = int(s**0.5) + x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2) + x = self.conv(x) + + x = x.flatten(2).transpose(1, 2) + x = self.linear_proj(x) + boi = self.boi.expand(x.shape[0], -1, -1) + eoi = self.eoi.expand(x.shape[0], -1, -1) + x = torch.cat((boi, x, eoi), dim=1) + x = x / self.scaling_factor + return x + + +class GLM4VModel(ChatGLMModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + quant_config = vllm_config.quant_config + + self.vision = EVA2CLIPModel(self.config, + quant_config, + prefix=f"{prefix}.vision") + + +class GLM4VProcessor: + """ + This model doesn't define its own HF processor, + so we implement our own one here. + """ + + def __init__( + self, + config: ChatGLMConfig, + tokenizer: PreTrainedTokenizer, + ) -> None: + super().__init__() + + self.config = config + self.tokenizer = tokenizer + + vision_config = config.vision_config + image_size = vision_config["image_size"] + + self.image_transform = transforms.Compose([ + transforms.Resize( + (image_size, image_size), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ]) + + def __call__( + self, + text: Optional[Union[TextInput, list[TextInput]]] = None, + images: Optional[Union[ImageInput, list[ImageInput]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + ) -> BatchFeature: + if text is None: + text = [] + if not isinstance(text, list): + text = [text] + if images is None: + images = [] + if not isinstance(images, list): + images = [images] + + text_inputs = self.tokenizer(text) + + if len(images) == 0: + image_inputs = {} + else: + pixel_values = [self.image_transform(image) for image in images] + image_inputs = {"pixel_values": torch.stack(pixel_values)} + + return BatchFeature( + { + **text_inputs, + **image_inputs, + }, + tensor_type=return_tensors, + ) + + +class GLM4VProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(ChatGLMConfig) + + def get_hf_processor(self, **kwargs: object) -> GLM4VProcessor: + return self.ctx.init_processor( + GLM4VProcessor, + config=self.get_hf_config(), + tokenizer=self.get_tokenizer(), + **kwargs, + ) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_num_image_tokens(self) -> int: + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + + image_size = vision_config["image_size"] + patch_size = vision_config["patch_size"] + grid_length = image_size // patch_size // 2 + return grid_length * grid_length + + def get_num_image_feature_tokens(self) -> int: + # EVA2CLIPModel has embeddings for boi and eoi tokens as well + return self.get_num_image_tokens() + 2 + + +class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>" + + return base_text * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + hf_config = self.info.get_hf_config() + vision_config = hf_config.vision_config + + target_width = target_height = vision_config["image_size"] + num_images = mm_counts.get("image", 0) + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + +class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_config = self.info.get_hf_config() + + boi_token_id = hf_config.boi_token_id + image_token_id = hf_config.pad_token_id + eoi_token_id = hf_config.eoi_token_id + + def get_replacement(item_idx: int): + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [image_token_id] * num_image_tokens + + return [boi_token_id] + image_tokens + [eoi_token_id] + + return [ + PromptReplacement( + modality="image", + target=[boi_token_id, image_token_id, eoi_token_id], + replacement=get_replacement, + ), + ] + + +@MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor, + info=GLM4VProcessingInfo, + dummy_inputs=GLM4VDummyInputsBuilder) +class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, + SupportsMultiModal): + + packed_modules_mapping = { + "query_key_value": ["query_key_value"], + "dense_h_to_4h": ["dense_h_to_4h"], + "merged_proj": ["gate_proj", "dense_h_to_4h"] + } + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="transformer.encoder", + connector="transformer.vision.linear_proj", + tower_model="transformer.vision.transformer") + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<|begin_of_image|><|endoftext|><|end_of_image|>" + + raise ValueError("Only image modality is supported") + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + transformer_type: type[GLM4VModel] = GLM4VModel, + ) -> None: + super().__init__( + vllm_config=vllm_config, + prefix=prefix, + transformer_type=transformer_type, + ) + + self.transformer: GLM4VModel + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config["image_size"] + expected_dims = (3, h, w) + actual_dims = tuple(data.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("batch_size", *map(str, expected_dims)) + raise ValueError( + f"The expected shape of pixel values is {expected_expr}. " + f"You supplied {tuple(data.shape)}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[GLMVImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + return GLMVImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), + ) + + return None + + def _process_image_input( + self, image_input: GLMVImagePixelInputs) -> torch.Tensor: + pixel_values = image_input["data"].to(dtype=self.config.torch_dtype) + + return self.transformer.vision(pixel_values) + + def get_language_model(self) -> torch.nn.Module: + return self.transformer + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.transformer.get_input_embeddings(input_ids) + + if multimodal_embeddings is not None \ + and len(multimodal_embeddings) != 0: + inputs_embeds = merge_multimodal_embeddings( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + placeholder_token_id=[ + self.config.boi_token_id, + self.config.pad_token_id, + self.config.eoi_token_id, + ], + ) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + + return hidden_states diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py new file mode 100644 index 0000000..2702155 --- /dev/null +++ b/vllm/model_executor/models/gpt2.py @@ -0,0 +1,382 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py +# Copyright 2023 The vLLM team. +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GPT-2 model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch import nn +from transformers import GPT2Config + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed.parallel_state import ( + get_pp_group, get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput + +from ..layers.pooler import Pooler, PoolingType +from .interfaces import SupportsPP +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class GPT2Attention(nn.Module): + + def __init__( + self, + config: GPT2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + total_num_heads = config.num_attention_heads + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = total_num_heads // tensor_model_parallel_world_size + self.head_dim = self.hidden_size // total_num_heads + self.scale = self.head_dim**-0.5 + + self.c_attn = QKVParallelLinear( + self.hidden_size, + self.head_dim, + total_num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.c_attn", + ) + self.c_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.c_proj", + ) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scale, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + attn_output = self.attn(q, k, v) + attn_output, _ = self.c_proj(attn_output) + return attn_output + + +class GPT2MLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: GPT2Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.hidden_size + self.c_fc = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.c_fc", + ) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.c_proj", + ) + self.act = get_act_fn(config.activation_function) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.c_proj(hidden_states) + return hidden_states + + +class GPT2Block(nn.Module): + + def __init__( + self, + config: GPT2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.hidden_size + inner_dim = (config.n_inner if config.n_inner is not None else 4 * + hidden_size) + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPT2Attention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = GPT2MLP(inner_dim, + config, + quant_config, + prefix=f"{prefix}.mlp") + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn(hidden_states=hidden_states) + # residual connection + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + return hidden_states + + +@support_torch_compile +class GPT2Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + assert not config.add_cross_attention + assert not config.scale_attn_by_inverse_layer_idx + assert not config.reorder_and_upcast_attn + self.embed_dim = config.hidden_size + self.wte = VocabParallelEmbedding(config.vocab_size, + self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.wte") + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: GPT2Block( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.h") + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.n_embd)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(hidden_states) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + + hidden_states = self.ln_f(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if ".attn.bias" in name or ".attn.masked_bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. + for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class GPT2LMHeadModel(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.transformer = GPT2Model(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.lm_head") + if self.config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights(self.transformer.wte) + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + weights = _add_transformer_prefix(weights) + return loader.load_weights(weights) + + +class GPT2ForSequenceClassification(nn.Module): + """GPT2 Model for sequence classification. + + This class expands GPT2Model with pooling and score functions - last token + is being used for classification. + + Attributes: + transformer: An instance of GPT2Model used for forward operations. + score: A layer for calculating logits. + _pooler: An instance of Pooler used for pooling operations. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.transformer = GPT2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "gpt2")) + self.score = nn.Linear(config.n_embd, config.num_labels, bias=False) + pooler_config = vllm_config.model_config.pooler_config + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.LAST, + normalize=False, + softmax=True) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.transformer( + input_ids=input_ids, + position_ids=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors) + logits = self.score(hidden_states) + return logits + + +def _add_transformer_prefix( + weights: Iterable[tuple[str, torch.Tensor]] +) -> Iterable[tuple[str, torch.Tensor]]: + for name, tensor in weights: + if not name.startswith('transformer.') and not name.startswith( + "lm_head"): + name = 'transformer.' + name + yield name, tensor diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py new file mode 100644 index 0000000..661a67b --- /dev/null +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -0,0 +1,335 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py +# Copyright 2023 The vLLM team. +# Copyright 2023 CTranslate2, and Michael Feil +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GPTBigCode model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch import nn +from transformers import GPTBigCodeConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class GPTBigCodeAttention(nn.Module): + + def __init__( + self, + config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + total_num_heads = config.num_attention_heads + self.tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert total_num_heads % self.tensor_model_parallel_world_size == 0 + self.num_heads = (total_num_heads // + self.tensor_model_parallel_world_size) + self.head_dim = self.hidden_size // total_num_heads + self.scale = self.head_dim**-0.5 + + self.multi_query = config.multi_query + if self.multi_query: + total_num_kv_heads = 1 + self.num_kv_heads = 1 + else: + total_num_kv_heads = total_num_heads + self.num_kv_heads = self.num_heads + self.kv_dim = self.head_dim * self.num_kv_heads + self.c_attn = QKVParallelLinear( + self.hidden_size, + self.head_dim, + total_num_heads, + total_num_kv_heads, + bias=True, + quant_config=quant_config, + ) + + self.c_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scale, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.split( + [ + self.hidden_size // self.tensor_model_parallel_world_size, + self.kv_dim, self.kv_dim + ], + dim=-1, + ) + attn_output = self.attn(q, k, v) + attn_output, _ = self.c_proj(attn_output) + return attn_output + + +class GPTBigMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: GPTBigCodeConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.hidden_size + self.c_fc = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + quant_config=quant_config, + ) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + quant_config=quant_config, + ) + self.act = get_act_fn(config.activation_function) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.c_proj(hidden_states) + return hidden_states + + +class GPTBigCodeBlock(nn.Module): + + def __init__( + self, + config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.hidden_size + inner_dim = (config.n_inner if config.n_inner is not None else 4 * + hidden_size) + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPTBigCodeAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = GPTBigMLP(inner_dim, config, quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn(hidden_states=hidden_states, ) + # residual connection + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + return hidden_states + + +@support_torch_compile +class GPTBigCodeModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + assert not config.add_cross_attention + + self.embed_dim = config.hidden_size + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.wte = VocabParallelEmbedding(self.vocab_size, + self.embed_dim, + org_num_embeddings=config.vocab_size) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: GPTBigCodeBlock( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.h", + ) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.n_embd)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + hidden_states = inputs_embeds + self.wpe(position_ids) + else: + hidden_states = intermediate_tensors["hidden_states"] + + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(hidden_states) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if ".attn.bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method + if "c_attn.input_scale" in name or "c_attn.weight_scale" in name: + weight_loader(param, loaded_weight, 'q') + weight_loader(param, loaded_weight, 'k') + weight_loader(param, loaded_weight, 'v') + else: + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = {"c_attn": ["c_attn"]} + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.transformer = GPTBigCodeModel(vllm_config=vllm_config, + prefix=prefix) + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.wte + else: + self.lm_head = ParallelLMHead( + self.transformer.vocab_size, + self.transformer.embed_dim, + org_num_embeddings=self.config.vocab_size) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + skip_prefixes = None + if self.config.tie_word_embeddings: + skip_prefixes = ["lm_head."] + loader = AutoWeightsLoader( + self, + skip_prefixes=skip_prefixes, + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py new file mode 100644 index 0000000..bd162a5 --- /dev/null +++ b/vllm/model_executor/models/gpt_j.py @@ -0,0 +1,339 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py +# Copyright 2023 The vLLM team. +# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GPT-J model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch import nn +from transformers import GPTJConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class GPTJAttention(nn.Module): + + def __init__( + self, + config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.total_num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.total_num_heads + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_size, + self.total_num_heads, + bias=False, + quant_config=quant_config, + ) + self.out_proj = RowParallelLinear( + config.hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + scaling = self.head_size**-0.5 + assert getattr(config, "rotary", True) + assert config.rotary_dim % 2 == 0 + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.rotary_emb = get_rope( + self.head_size, + rotary_dim=config.rotary_dim, + max_position=max_position_embeddings, + base=rope_theta, + is_neox_style=False, + ) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(position_ids, q, k) + attn_output = self.attn(q, k, v) + attn_output, _ = self.out_proj(attn_output) + return attn_output + + +class GPTJMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: GPTJConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.n_embd + self.fc_in = ColumnParallelLinear( + hidden_size, + intermediate_size, + quant_config=quant_config, + ) + self.fc_out = RowParallelLinear( + intermediate_size, + hidden_size, + quant_config=quant_config, + ) + self.act = get_act_fn(config.activation_function) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.fc_out(hidden_states) + return hidden_states + + +class GPTJBlock(nn.Module): + + def __init__( + self, + config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + inner_dim = (4 * config.n_embd + if config.n_inner is None else config.n_inner) + self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.attn = GPTJAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") + self.mlp = GPTJMLP(inner_dim, config, quant_config) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn( + position_ids=position_ids, + hidden_states=hidden_states, + ) + mlp_output = self.mlp(hidden_states) + hidden_states = attn_output + mlp_output + residual + return hidden_states + + +@support_torch_compile +class GPTJModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.quant_config = quant_config + self.embed_dim = config.n_embd + self.wte = VocabParallelEmbedding( + config.vocab_size, + self.embed_dim, + ) + self.start_layer, self.end_layer, self.h = make_layers( + config.n_layer, + lambda prefix: GPTJBlock( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.h", + ) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.n_embd)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + hidden_states = intermediate_tensors["hidden_states"] + for layer in self.h[self.start_layer:self.end_layer]: + hidden_states = layer(position_ids, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "attn.bias" in name or "attn.masked_bias" in name: + continue + + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class GPTJForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + assert not config.tie_word_embeddings + self.transformer = GPTJModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "transformer")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.n_embd, + bias=True, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata, self.lm_head.bias) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py new file mode 100644 index 0000000..0f46d75 --- /dev/null +++ b/vllm/model_executor/models/gpt_neox.py @@ -0,0 +1,340 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GPT-NeoX model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Optional, Union + +import os +import re +import torch +from torch import nn +from transformers import GPTNeoXConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) +from vllm import _custom_ops as ops + +class GPTNeoXAttention(nn.Module): + + def __init__( + self, + config: GPTNeoXConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.total_num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.total_num_heads + self.bias = getattr(config, "attention_bias", True) + + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + + self.query_key_value = QKVParallelLinear( + config.hidden_size, + self.head_size, + self.total_num_heads, + bias=self.bias, + quant_config=quant_config, + ) + self.dense = RowParallelLinear( + config.hidden_size, + config.hidden_size, + bias=self.bias, + quant_config=quant_config, + ) + scaling = self.head_size**-0.5 + rotary_dim = int(self.head_size * config.rotary_pct) + assert rotary_dim % 2 == 0 + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.rotary_emb = get_rope( + self.head_size, + rotary_dim=rotary_dim, + max_position=max_position_embeddings, + base=rope_theta, + ) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(position_ids, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.dense(attn_output) + return output + + +class GPTNeoXMLP(nn.Module): + + def __init__( + self, + config: GPTNeoXConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.dense_h_to_4h = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config, + ) + self.dense_4h_to_h = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + quant_config=quant_config, + ) + self.act = get_act_fn(config.hidden_act) + + def forward(self, hidden_states): + hidden_states, _ = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.dense_4h_to_h(hidden_states) + return hidden_states + + +class GPTNeoXLayer(nn.Module): + + def __init__( + self, + config: GPTNeoXConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.use_parallel_residual = config.use_parallel_residual + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.attention = GPTNeoXAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attention") + self.mlp = GPTNeoXMLP(config, quant_config) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + attn_input = self.input_layernorm(hidden_states) + attn_output = self.attention( + position_ids=position_ids, + hidden_states=attn_input, + ) + + if self.use_parallel_residual: + # pseudocode: + # x = x + attn(ln1(x)) + mlp(ln2(x)) + mlp_input = self.post_attention_layernorm(hidden_states) + mlp_output = self.mlp(mlp_input) + hidden_states = mlp_output + attn_output + hidden_states + else: + # pseudocode: + # x = x + attn(ln1(x)) + # x = x + mlp(ln2(x)) + attn_output = attn_output + hidden_states + mlp_input = self.post_attention_layernorm(attn_output) + mlp_output = self.mlp(mlp_input) + hidden_states = mlp_output + attn_output + return hidden_states + + +@support_torch_compile +class GPTNeoXModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + + self.embed_in = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: GPTNeoXLayer( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + self.quant_method = None + if quant_config is not None: + self.quant_method=quant_config.get_name() + self.quant_config=quant_config + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_in(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + else: + hidden_states = intermediate_tensors["hidden_states"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(position_ids, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if ("attention.bias" in name or "attention.masked_bias" in name + or "rotary_emb.inv_freq" in name): + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using OpenRLHF may include + # these tensors in the checkpoint. Skip them. + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + + if "query_key_value" in name: + # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). + # Thus, we need weight conversion. + output_dim = getattr(param, "output_dim", None) + num_heads = self.config.num_attention_heads + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class GPTNeoXForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "gpt_neox")) + self.embed_out = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + if self.config.tie_word_embeddings: + self.embed_out.weight = self.gpt_neox.embed_in.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.gpt_neox.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.gpt_neox.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.gpt_neox(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.embed_out, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py new file mode 100644 index 0000000..bd4d5d0 --- /dev/null +++ b/vllm/model_executor/models/granite.py @@ -0,0 +1,493 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only IBM Granite model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import GraniteConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_layers, maybe_prefix) + + +class GraniteMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class GraniteAttention(nn.Module): + + def __init__( + self, + config: GraniteConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", None) + if self.head_dim is None: + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = config.attention_multiplier + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class GraniteDecoderLayer(nn.Module): + + def __init__( + self, + config: GraniteConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.residual_multiplier = config.residual_multiplier + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.self_attn = GraniteAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + + self.mlp = GraniteMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + return hidden_states + + +@support_torch_compile +class GraniteModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.quant_config = quant_config + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: GraniteDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers") + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + + hidden_states *= self.config.embedding_multiplier + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states = self.norm(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + + self.model = GraniteModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + if hasattr(config, "logits_scaling"): + logit_scale /= config.logits_scaling + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + scale=logit_scale) + else: + self.lm_head = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + skip_prefixes = (["lm_head."] + if self.config.tie_word_embeddings else None) + + loader = AutoWeightsLoader( + self, + skip_prefixes=skip_prefixes, + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py new file mode 100644 index 0000000..6c7c9f5 --- /dev/null +++ b/vllm/model_executor/models/granite_speech.py @@ -0,0 +1,790 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2025 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only IBM Granite speech model.""" +import math +from collections.abc import Iterable, Mapping +from typing import Optional, TypedDict, Union + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import BatchFeature, PretrainedConfig + +from vllm.config import CacheConfig, VllmConfig +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import get_sampler +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) +from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, + MultiModalDataParser) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors + +from .blip2 import Blip2QFormerModel +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .utils import (AutoWeightsLoader, embed_multimodal, + init_vllm_registered_model, maybe_prefix) + + +### Audio Input +class GraniteSpeechAudioInputs(TypedDict): + + input_features: torch.Tensor + """Shape: `(bsz, num_features, 160)`""" + + input_features_mask: torch.Tensor + """Shape: `(bsz, num_features)`""" + + audio_embed_sizes: list[int] + """List of length `bsz`""" + + +class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": 1} + + # There is no limit to the maximum number of audio tokens that can be + # encoded as features; we pick ~5000 as a number that is probably higher + # than we would expect to encounter. The sequence of length + # get_max_audio_len() produces get_max_audio_tokens(). + def get_max_audio_tokens(self): + return 5001 + + def get_max_audio_len(self): + return 8000000 + + +### Input Processing & Multimodal utils +class GraniteSpeechMultiModalProcessor( + BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo]): + + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_hf_processor().audio_processor + sampling_rate = feature_extractor.melspec_kwargs["sample_rate"] + return MultiModalDataParser(target_sr=sampling_rate) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + input_features=MultiModalFieldConfig.batched("audio"), + audio_embed_sizes=MultiModalFieldConfig.batched("audio"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + feature_extractor = processor.audio_processor + vocab = tokenizer.get_vocab() + + # Use getattr with default to be compatible with transformers<4.48 + audio_token = getattr(processor, "audio_token", "<|audio|>") + audio_token_id = vocab[audio_token] + + def get_replacement(item_idx: int): + audios = mm_items.get_items("audio", AudioProcessorItems) + audio = audios.get(item_idx) + audio_length = audio.shape[-1] + num_projector_features = feature_extractor._get_num_audio_features( + [audio_length])[0] + return [audio_token_id] * num_projector_features + + return [ + PromptReplacement( + modality="audio", + target=[audio_token_id], + replacement=get_replacement, + ) + ] + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + mm_data = dict(mm_data) + audios = mm_data.pop("audios", []) + + if audios: + # GraniteSpeechFeatureExtractor accepts "audio" + mm_data["audio"] = audios + + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + if "audio" in mm_data: + # Calculate the number of audio tokens per entry in the batch; + # This is used to split the batch back out after padding. + audio_token_index = self.info.get_hf_config().audio_token_index + processed_outputs["audio_embed_sizes"] = [ + torch.sum(indices == audio_token_index).item() + for indices in processed_outputs["input_ids"] + ] + + return processed_outputs + + +class GraniteSpeechDummyInputsBuilder( + BaseDummyInputsBuilder[GraniteSpeechMultiModalProcessingInfo]): + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_audios = mm_counts.get("audio", 0) + return { + "audio": + self._get_dummy_audios( + length=self.info.get_max_audio_len(), + num_audios=num_audios, + ) + } + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + hf_processor = self.info.get_hf_processor() + audio_token = getattr(hf_processor, "audio_token", "<|audio|>") + return audio_token * num_audios + + +### QFormer Projector +class GraniteSpeechEncoderProjector(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: CacheConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.projector_config.hidden_size + self.downsample_rate = config.downsample_rate + self.window_size = config.window_size + self.num_queries = config.window_size // config.downsample_rate + + self.query = nn.Parameter( + torch.zeros(1, self.num_queries, + config.projector_config.hidden_size)) + + # NOTE - this is implemented generically in transformers, + # but for now we create the QFormer model directly since + # all existing models use this for the projector. + self.qformer = Blip2QFormerModel( + config.projector_config, + quant_config=quant_config, + cache_config=cache_config, + prefix=f"{prefix}.qformer", + ) + self.linear = nn.Linear(config.projector_config.hidden_size, + config.text_config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = hidden_states.size() + nblocks = math.ceil(seq_len / self.window_size) + pad = nblocks * self.window_size - seq_len + hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad), + "constant", 0) + hidden_states = hidden_states.view(batch_size * nblocks, + self.window_size, dim) + + last_hidden_state = self.qformer( + query_embeds=self.query.data, + encoder_hidden_states=hidden_states, + ) + + query_proj = self.linear( + last_hidden_state.view( + batch_size, + nblocks * self.window_size // self.downsample_rate, + -1, + )) + return query_proj + + +# Encoder - conformer is adapted from: https://github.com/lucidrains/conformer.git +# NOTE - it would be nice to see if we can align this with other models using +# conformer in vLLM, e.g., phi4mm audio. +class GraniteSpeechConformerFeedForward(nn.Module): + """Feedforward module for conformer encoder blocks.""" + + def __init__(self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.pre_norm = nn.LayerNorm(config.hidden_dim) + + self.up_proj = ColumnParallelLinear( + input_size=config.hidden_dim, + output_size=config.hidden_dim * config.feedforward_mult, + quant_config=quant_config, + prefix=f"{prefix}.up_proj", + ) + self.silu = nn.SiLU() + + self.down_proj = RowParallelLinear( + input_size=config.hidden_dim * config.feedforward_mult, + output_size=config.hidden_dim, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.pre_norm(hidden_states) + hidden_states, _ = self.up_proj(hidden_states) + hidden_states = self.silu(hidden_states) + hidden_states, _ = self.down_proj(hidden_states) + return hidden_states + + +class GraniteSpeechConformerAttention(nn.Module): + """Attention for conformer blocks using Shaw's relative positional + embeddings. See the following [paper](https://arxiv.org/pdf/1803.02155) + for more details. + """ + + def __init__(self, config: PretrainedConfig, prefix: str = ""): + super().__init__() + + inner_dim = config.dim_head * config.num_heads + self.max_pos_emb = config.max_pos_emb + self.context_size = config.context_size + self.num_heads = config.num_heads + self.dim_head = config.dim_head + self.scale = self.dim_head**-0.5 + self.pre_norm = nn.LayerNorm(config.hidden_dim) + self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False) + self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, config.hidden_dim) + self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, + self.dim_head) + + if self.context_size <= 0 or self.context_size > self.max_pos_emb: + raise ValueError( + "Context size is either less than 0 or exceeds the max_pos_emb" + ) + + def forward(self, hidden_states: torch.Tensor, + attention_dists: torch.Tensor) -> torch.Tensor: + hidden_states = self.pre_norm(hidden_states) + bsz, num_features, _ = hidden_states.shape + + num_blocks = math.ceil(num_features / self.context_size) + remainder = num_features % self.context_size + if remainder > 0: + # right padding to reach block size + hidden_states = torch.nn.functional.pad( + hidden_states, (0, 0, 0, self.context_size - remainder)) + + # NOTE: would be nice to try to use qkvparallellinear + # here for this block attention implementation if possible + query_states = self.to_q(hidden_states) + key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1) + + query_states = query_states.reshape(bsz, num_blocks, self.context_size, + self.num_heads, + -1).transpose(2, 3) + key_states = key_states.reshape(bsz, num_blocks, self.context_size, + self.num_heads, -1).transpose(2, 3) + value_states = value_states.reshape(bsz, num_blocks, self.context_size, + self.num_heads, + -1).transpose(2, 3) + + # shaw's relative positional embedding + dist = attention_dists.to(hidden_states.device) + rel_pos_emb = self.rel_pos_emb(dist) + rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + + list(rel_pos_emb.shape)) + pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, + dim=-1) * self.scale + + if remainder > 0: + # masked attention in the extended block + mask = torch.ones(self.context_size, + self.context_size, + dtype=bool, + device=hidden_states.device) + mask[:remainder, :remainder] = 0 + mask_value = -torch.finfo(pos_attn.dtype).max + pos_attn[:, -1, :].masked_fill_(mask, mask_value) + + with torch.nn.attention.sdpa_kernel( + torch.nn.attention.SDPBackend.MATH): + out = F.scaled_dot_product_attention(query_states, + key_states, + value_states, + attn_mask=pos_attn, + scale=self.scale) + out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1) + return self.to_out(out[:, :num_features, :]) + + +class GraniteSpeechConformerDepthWiseConv1d(nn.Module): + """Wrapper for padded 1D pointwise convolution.""" + + def __init__(self, + chan_in: int, + chan_out: int, + kernel_size: int, + prefix: str = ""): + super().__init__() + # Padding for the 1D conv is symmetric or close (i.e., offset by one). + pad = kernel_size // 2 + pad_offset = (kernel_size + 1) % 2 + self.padding = (pad, pad - pad_offset) + + self.conv = nn.Conv1d(chan_in, + chan_out, + kernel_size, + groups=chan_in, + bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.pad(hidden_states, self.padding) + return self.conv(hidden_states) + + +class GraniteSpeechConformerConvModule(nn.Module): + """Conformer conv module consisting of several 1D/depthwise 1D + convolutional layers. + """ + + def __init__(self, config: PretrainedConfig, prefix: str = ""): + super().__init__() + inner_dim = config.hidden_dim * config.conv_expansion_factor + + self.norm = nn.LayerNorm(config.hidden_dim) + self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1) + self.glu = nn.GLU(dim=1) + self.depth_conv = GraniteSpeechConformerDepthWiseConv1d( + inner_dim, + inner_dim, + kernel_size=config.conv_kernel_size, + prefix=f"{prefix}.depth_conv", + ) + self.silu = nn.SiLU() + self.batch_norm = nn.BatchNorm1d(inner_dim) + self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.norm(hidden_states) + hidden_states = self.up_conv(hidden_states.permute(0, 2, 1)) + hidden_states = self.glu(hidden_states) + hidden_states = self.depth_conv(hidden_states) + hidden_states = self.silu(self.batch_norm(hidden_states)) + hidden_states = self.down_conv(hidden_states).permute(0, 2, 1) + return hidden_states + + +class GraniteSpeechConformerBlock(nn.Module): + """Conformer block, consisting largely of linear layers, + attention, and convolutional layers.""" + + def __init__(self, config: PretrainedConfig, prefix: str = ""): + super().__init__() + self.ff1 = GraniteSpeechConformerFeedForward(config, + prefix=f"{prefix}.ff1") + self.attn = GraniteSpeechConformerAttention(config, + prefix=f"{prefix}.attn") + self.conv = GraniteSpeechConformerConvModule(config, + prefix=f"{prefix}.conv") + self.ff2 = GraniteSpeechConformerFeedForward(config, + prefix=f"{prefix}.ff2") + self.post_norm = nn.LayerNorm(config.hidden_dim) + + def forward(self, hidden_states: torch.Tensor, + attention_dists: torch.Tensor) -> torch.Tensor: + hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states + hidden_states = self.attn( + hidden_states, attention_dists=attention_dists) + hidden_states + hidden_states = self.conv(hidden_states) + hidden_states + hidden_states = 0.5 * self.ff2(hidden_states) + hidden_states + hidden_states = self.post_norm(hidden_states) + return hidden_states + + +class GraniteSpeechCTCEncoder(nn.Module): + """CTC Encoder comprising conformer blocks and additional linear layers.""" + + def __init__(self, + config: PretrainedConfig, + prefix: str, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + + # Precompute clamped relative positional encoding distances + seq = torch.arange(config.context_size) + relpos_dist = seq.view(-1, 1) - seq.view(1, -1) + self.attention_dists = torch.clamp( + relpos_dist, -config.context_size, + config.context_size) + config.max_pos_emb + + self.input_linear = nn.Linear(config.input_dim, + config.hidden_dim, + bias=True) + self.layers = nn.ModuleList([ + GraniteSpeechConformerBlock( + config, + prefix=f"{prefix}.layers.{idx}", + ) for idx in range(config.num_layers) + ]) + + self.out = ColumnParallelLinear( + input_size=config.hidden_dim, + output_size=config.output_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out", + ) + + self.out_mid = RowParallelLinear( + input_size=config.output_dim, + output_size=config.hidden_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out_mid", + ) + self.softmax = nn.Softmax(dim=-1) + self.num_layers = config.num_layers + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.input_linear(hidden_states) + for idx, layer in enumerate(self.layers, start=1): + hidden_states = layer(hidden_states, + attention_dists=self.attention_dists) + + if idx == self.num_layers // 2: + hidden_states_mid = hidden_states.clone() + hidden_states_mid, _ = self.out(hidden_states_mid) + hidden_states_mid = self.softmax(hidden_states_mid) + hidden_states_mid, _ = self.out_mid(hidden_states_mid) + hidden_states += hidden_states_mid + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor( + GraniteSpeechMultiModalProcessor, + info=GraniteSpeechMultiModalProcessingInfo, + dummy_inputs=GraniteSpeechDummyInputsBuilder) +class GraniteSpeechForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsPP, + SupportsLoRA, +): + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("audio"): + return "<|audio|>" + + raise ValueError("Only audio modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache_config + + self.config = config + self.quant_config = quant_config + self.cache_config = cache_config + self.sampler = get_sampler() + + # The language model is typically a Granite LLM + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + # Conformer encoder + self.encoder = GraniteSpeechCTCEncoder( + config=config.encoder_config, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + ) + + # Blip2 QFormer + self.projector = GraniteSpeechEncoderProjector( + config=config, + quant_config=quant_config, + cache_config=cache_config, + prefix=f"{prefix}.projector", + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def _parse_and_validate_audio_input( + self, + **kwargs: object, + ) -> Optional[GraniteSpeechAudioInputs]: + input_features = kwargs.pop("input_features", None) + input_features_mask = kwargs.pop("input_features_mask", None) + audio_embed_sizes = kwargs.pop("audio_embed_sizes", None) + if input_features is None: + return None + + # If we have a batch of variable feature length audio clips, we need + # to mask the features; usually we would get an input_features_mask + # from the processor, but we handle rebuilding it here since + # vLLM generally processes everything independently + batches. + if input_features_mask is None: + input_features_mask = self._build_input_features_mask( + audio_embed_sizes) + + if not isinstance(input_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio input features. " + f"Got type: {type(input_features)}") + + if input_features_mask is not None and not isinstance( + input_features_mask, torch.Tensor): + raise ValueError("Incorrect type of audio input features mask. " + f"Got type: {type(input_features_mask)}") + + if isinstance(input_features, torch.Tensor): + # Granite speech currently only allows one audio token per instance + # and features are already unsqueezed in the processor, so one + # instance will have shape [1, {num_features}, 160]. As such, + # input features will usually be of shape + # [bsz, 1, num_features, 160], which we squeeze to be 3D here. + if len(input_features.shape) == 4: + input_features = input_features.squeeze(1) + if len(input_features.shape) != 3: + raise ValueError( + "Squeezed input features should be 3D but are of shape " + f"{input_features.shape}") + input_features = input_features.to( + self.encoder.input_linear.weight.dtype) + + else: + # Otherwise we have a list of tensors, which are almost certainly + # differing in their respective numbers of audio features; + # stack them into a 3D tensor of size [bsz, most_num_features, 160]. + input_features = self._pad_and_stack_input_features( + input_features, ).to(self.encoder.input_linear.weight.dtype) + + return GraniteSpeechAudioInputs( + input_features=input_features, + input_features_mask=input_features_mask, + audio_embed_sizes=audio_embed_sizes.flatten().tolist(), + ) + + def _build_input_features_mask( + self, + audio_embed_sizes: torch.Tensor, + ) -> torch.Tensor: + """Calculate the input features mask, which will generally be used + to mask the padded features for all entries in the batch except + for those with the most audio features. + + Args: + audio_embed_sizes: torch.Tensor + Tensor of num features in each seq in the batch. + Returns: + torch.Tensor: Mask of shape (bsz, num_features) to be applied to + the audio features prior to splitting the audio embeddings. + """ + most_audio_features = torch.max(audio_embed_sizes).item() + mask_indices = torch.arange( + most_audio_features, + device=audio_embed_sizes.device, + ).view(1, -1) + input_features_mask = mask_indices < audio_embed_sizes.view(-1, 1) + return input_features_mask + + def _pad_and_stack_input_features( + self, + input_features: list[torch.Tensor], + ) -> torch.Tensor: + """Given a list of input features of varying length, pad them to the + same length and stack them into a torch.Tensor. + + NOTE: Usually, padding is done in the input processor/feature extractor + and zero padded prior to the computation of the Mel features; the + resulting values are only constant within a batch and generally nonzero + (i.e., slightly negative nums); we should validate that this is okay + since we don't use a feature attention mask, but the more important + thing is that we apply the input_features_mask with variable len + batches. + + Args: + input_features: list[torch.Tensor] + Input features to be coerced into a tensor. + Returns: + torch.Tensor: Tensor of shape [bsz, num_features, 160], where + num_features is the max number of features of any entry in the + batch. + """ + # Input features are of shape [bsz, num_features, 160] + feat_lens = [feats.shape[1] for feats in input_features] + padding = [max(feat_lens) - length for length in feat_lens] + # TODO (Alex) - Validate that it's okay to zero pad like this; + # in transformers we zero pad prior to calculating the speech features, + # so the value is not zero and is dependent on the batched features. + padded = [ + torch.nn.functional.pad(feats, (0, 0, 0, pad, 0, 0)) + for feats, pad in zip(input_features, padding) + ] + stacked_features = torch.cat(padded, dim=0).to(input_features[0]) + return stacked_features + + def _process_audio_input( + self, + audio_input: GraniteSpeechAudioInputs, + ) -> tuple[torch.Tensor]: + """Compute the audio features to be merged into the LLM embeddings. + + Args: + audio_input: GraniteSpeechAudioInputs + Audio inputs object containing Mel features, an input features + mask, and the (flattened) number of audio tokens per instance. + Returns: + tuple[torch.Tensor]: List of length bsz. + """ + # TODO (Alex) - support embedding inputs + encoder_embeds = self.encoder(audio_input["input_features"]) + # [bsz, , 4096] + projected_embeds = self.projector(encoder_embeds) + # Apply mask on variable length audio features + masked_embeds = projected_embeds[audio_input["input_features_mask"]] + # Split variable length features into a tuple + return torch.split(masked_embeds, audio_input["audio_embed_sizes"]) + + def get_multimodal_embeddings( + self, + **kwargs: object, + ) -> MultiModalEmbeddings: + """Compute the audio embeddings if audio inputs are present.""" + audio_input = self._parse_and_validate_audio_input(**kwargs) + if audio_input is None: + return [] + return None + audio_features = self._process_audio_input(audio_input) + return audio_features + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + """Compute the merged LLM / audio embeddings.""" + if multimodal_embeddings is None \ + or len(multimodal_embeddings) == 0: + return self.language_model.get_input_embeddings(input_ids) + + inputs_embeds = embed_multimodal( + input_ids, + self.config.audio_token_index, + self.language_model.model.get_input_embeddings, + multimodal_embeddings, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + audio_embeds = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, audio_embeds) + input_ids = None + + model_output = self.language_model(input_ids, positions, + intermediate_tensors, inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits( + hidden_states, + sampling_metadata, + ) + + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]], + ) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def get_mm_mapping(self) -> MultiModelKeys: + """Get the module prefix in multimodal models.""" + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="projector", + tower_model="encoder", + ) diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py new file mode 100644 index 0000000..5a70f3a --- /dev/null +++ b/vllm/model_executor/models/granitemoe.py @@ -0,0 +1,437 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GraniteMoe model.""" +from collections.abc import Iterable +from typing import Optional + +import torch +from torch import nn +from transformers.models.granitemoe import GraniteMoeConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from . import mixtral +from .interfaces import SupportsLoRA, SupportsPP +from .utils import AutoWeightsLoader, make_layers, maybe_prefix + + +class GraniteMoeMoE(nn.Module): + """A tensor-parallel MoE implementation for GraniteMoe that shards each + expert across all ranks. + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__(self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear(hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate") + + self.experts = FusedMoE(num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + prefix=f"{prefix}.experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class GraniteMoeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + attention_multiplier: Optional[float] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = (attention_multiplier if attention_multiplier + is not None else self.head_dim**-1) + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class GraniteMoeDecoderLayer(nn.Module): + + def __init__( + self, + config: GraniteMoeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = GraniteMoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + attention_multiplier=config.attention_multiplier) + self.block_sparse_moe = GraniteMoeMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.block_sparse_moe") + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + self.residual_multiplier = config.residual_multiplier + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + + return hidden_states + + +@support_torch_compile +class GraniteMoeModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.quant_config = quant_config # Required by MixtralModel + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.embedding_multiplier = config.embedding_multiplier + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: GraniteMoeDecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers") + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + hidden_states *= self.embedding_multiplier + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states = layer(positions, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states = self.norm(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + new_weights = {} + for n, p in weights: + if n.endswith('.block_sparse_moe.input_linear.weight'): + for e in range(p.size(0)): + w1_name = n.replace( + '.block_sparse_moe.input_linear.weight', + f".block_sparse_moe.experts.{e}.w1.weight") + w3_name = n.replace( + '.block_sparse_moe.input_linear.weight', + f".block_sparse_moe.experts.{e}.w3.weight") + w1_param, w3_param = p[e].chunk(2, dim=0) + assert w1_name not in new_weights + assert w3_name not in new_weights + new_weights[w1_name] = w1_param + new_weights[w3_name] = w3_param + elif n.endswith('.block_sparse_moe.output_linear.weight'): + for e in range(p.size(0)): + w2_name = n.replace( + '.block_sparse_moe.output_linear.weight', + f".block_sparse_moe.experts.{e}.w2.weight") + w2_param = p[e] + assert w2_name not in new_weights + new_weights[w2_name] = w2_param + elif n.endswith('.block_sparse_moe.router.layer.weight'): + gate_name = n.replace('.block_sparse_moe.router.layer.weight', + ".block_sparse_moe.gate.weight") + assert gate_name not in new_weights + new_weights[gate_name] = p + else: + new_weights[n] = p + return mixtral.MixtralModel.load_weights(self, new_weights.items()) + + +class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.model = GraniteMoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + scale=1 / + self.config.logits_scaling) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py new file mode 100644 index 0000000..676ef24 --- /dev/null +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -0,0 +1,653 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only GraniteMoeHybrid model.""" +# Added by the IBM Team, 2025 +from collections.abc import Iterable +from typing import Optional + +import torch +from torch import nn +from transformers import GraniteMoeHybridConfig + +from vllm import envs +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.mamba2_metadata import ( + Mamba2Metadata, prepare_mamba2_metadata) +from vllm.model_executor.layers.mamba.mamba_mixer2 import ( + MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import LayerBlockType + +from .granitemoe import GraniteMoeMoE +from .granitemoeshared import GraniteMoeSharedMLP +from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, + SupportsQuant) +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class GraniteMoeHybridMambaDecoderLayer(nn.Module): + + def __init__(self, + config: GraniteMoeHybridConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.residual_multiplier = config.residual_multiplier + + self.mamba = MambaMixer2(hidden_size= config.hidden_size, + ssm_state_size = config.mamba_d_state, + conv_kernel_size = config.mamba_d_conv, + intermediate_size = config.mamba_expand *\ + config.hidden_size, + use_conv_bias = config.mamba_conv_bias, + use_bias = config.mamba_proj_bias, + n_groups=config.mamba_n_groups, + num_heads=config.mamba_n_heads, + head_dim=config.mamba_d_head, + rms_norm_eps=config.rms_norm_eps, + activation=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mixer", + chunk_size=config.mamba_chunk_size) + + self.block_sparse_moe = None + if getattr(config, "num_local_experts", 0) > 0: + self.block_sparse_moe = GraniteMoeMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.block_sparse_moe") + + self.shared_mlp = None if \ + getattr(config, 'shared_intermediate_size', 0) == 0 \ + else GraniteMoeSharedMLP( + config, + quant_config=quant_config, + prefix=f"{prefix}.shared_mlp" + ) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + **kwargs, + ): + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.mamba(hidden_states, mamba_cache_params, + mamba2_metadata) + hidden_states = residual + hidden_states * self.residual_multiplier + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + if self.shared_mlp is None: + if self.block_sparse_moe is not None: + hidden_states = self.block_sparse_moe(hidden_states) + # else: skip + else: + # create a copy since block_sparse_moe modifies in-place + if self.block_sparse_moe is not None: + moe_hidden_states = hidden_states.clone() + moe_hidden_states = self.block_sparse_moe(moe_hidden_states) + hidden_states = moe_hidden_states + self.shared_mlp( + hidden_states) + del moe_hidden_states + else: + hidden_states = self.shared_mlp(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + + return hidden_states, residual + + +class GraniteMoeHybridAttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: GraniteMoeHybridConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.residual_multiplier = config.residual_multiplier + + self.self_attn = GraniteMoeHybridAttention( + config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn") + + self.block_sparse_moe = None + if getattr(config, "num_local_experts", 0) > 0: + self.block_sparse_moe = GraniteMoeMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.block_sparse_moe") + + self.shared_mlp = None if \ + getattr(config, 'shared_intermediate_size', 0) == 0 \ + else GraniteMoeSharedMLP( + config, + quant_config=quant_config, + prefix=f"{prefix}.shared_mlp" + ) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + if self.shared_mlp is None: + if self.block_sparse_moe is not None: + hidden_states = self.block_sparse_moe(hidden_states) + # else: skip + else: + # create a copy since block_sparse_moe modifies in-place + if self.block_sparse_moe is not None: + moe_hidden_states = hidden_states.clone() + moe_hidden_states = self.block_sparse_moe(moe_hidden_states) + hidden_states = moe_hidden_states + self.shared_mlp( + hidden_states) + del moe_hidden_states + else: + hidden_states = self.shared_mlp(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + + return hidden_states, residual + + +class GraniteMoeHybridAttention(nn.Module): + + def __init__( + self, + config: GraniteMoeHybridConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.causal = True + self.hidden_size = config.hidden_size + self.attention_bias = config.attention_bias + self.attention_multiplier = config.attention_multiplier + self.total_num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.total_num_heads + self.total_num_kv_heads = config.num_key_value_heads + + # TensorParallel logic + tp_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_key_value_heads = max(1, self.total_num_kv_heads // tp_size) + + self.qkv_proj = QKVParallelLinear(self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=self.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.o_proj = RowParallelLinear(self.hidden_size, + self.hidden_size, + bias=self.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + if config.position_embedding_type == "rope": + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=config.max_position_embeddings, + base=int(config.rope_theta), + rope_scaling=config.rope_scaling \ + if hasattr(config, "rope_scaling") \ + and config.rope_scaling is not None else None, + is_neox_style=True, + ) + else: + self.rotary_emb = None + + self.attn = Attention(self.num_heads, + self.head_dim, + self.attention_multiplier, + num_kv_heads=self.num_key_value_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + + qkv, _ = self.qkv_proj(hidden_states) + query, key, value = qkv.split([ + self.num_heads * self.head_dim, self.num_key_value_heads * + self.head_dim, self.num_key_value_heads * self.head_dim + ], + dim=-1) + + if self.rotary_emb is not None: + query, key = self.rotary_emb(positions, query, key) + + hidden_states = self.attn(query, key, value) + del query, key, value + + hidden_states = self.o_proj(hidden_states)[0] + return hidden_states + + +ALL_DECODER_LAYER_TYPES = { + "attention": GraniteMoeHybridAttentionDecoderLayer, + "mamba": GraniteMoeHybridMambaDecoderLayer, +} + + +class GraniteMoeHybridModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.embedding_multiplier = config.embedding_multiplier + + def get_layer(prefix: str): + layer_idx = int(prefix.rsplit(".", 1)[1]) + layer_class = ALL_DECODER_LAYER_TYPES[ + config.layer_types[layer_idx]] + return layer_class( + config, + layer_idx, + cache_config, + quant_config=quant_config, + prefix=prefix, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mamba_cache_params: MambaCacheParams, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + attn_metadata = get_forward_context().attn_metadata + + if not envs.VLLM_USE_V1: + mamba2_metadata = prepare_mamba2_metadata( + chunk_size=self.config.mamba_chunk_size, + attn_metadata=attn_metadata, + ) + else: + # v1 get mamba2_metadata from forward_context + mamba2_metadata = None + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + hidden_states = hidden_states * self.embedding_multiplier + residual = None + else: + if intermediate_tensors is None: + raise RuntimeError('Intermediate tensors may not be None!') + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + num_attn = 0 + for i in range(len(self.layers)): + layer = self.layers[i] + if isinstance(layer, GraniteMoeHybridAttentionDecoderLayer): + num_attn += 1 + + layer_mamba_cache_params = None + if isinstance( + layer, + GraniteMoeHybridMambaDecoderLayer) and mamba_cache_params: + layer_mamba_cache_params = mamba_cache_params.at_layer_idx( + i - num_attn) + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + mamba_cache_params=layer_mamba_cache_params, + mamba2_metadata=mamba2_metadata) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states = self.norm(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + def _load(n, p): + param = params_dict[n] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, p) + loaded_params.add(n) + + def _load_shard(n, p, shard_id): + # Skip layers on other devices. + if not is_pp_missing_parameter(n, self): + param = params_dict[n] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, p, shard_id) + loaded_params.add(n) + + def _load_expert(n, p, name, shard_id, expert_id): + param = params_dict[n] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, + p, + name, + shard_id=shard_id, + expert_id=expert_id) + loaded_params.add(n) + + for n, p in weights: + if "A_log" in n: + n = n.replace("A_log", "A") + + # Logic analogous to: https://github.com/vllm-project/vllm/blob/f49e5aff11c986ed4d45202b1716c5d74786efa9/vllm/model_executor/models/granitemoeshared.py#L215 + # Mapping different experts' layout: + # from HF (input_linear, output_linear, router) + # to vLLM (experts_w13({e}.w1, {e}.w2), experts_w3({e}.w3), gate) + if n.endswith('.block_sparse_moe.input_linear.weight'): + for e in range(p.size(0)): + w1_name = n.replace( + '.block_sparse_moe.input_linear.weight', + f".block_sparse_moe.experts.{e}.w1.weight") + w3_name = n.replace( + '.block_sparse_moe.input_linear.weight', + f".block_sparse_moe.experts.{e}.w3.weight") + w1_param, w3_param = p[e].chunk(2, dim=0) + _load_expert(n.replace('.input_linear.', '.experts.w13_'), + w1_param, + w1_name, + shard_id='w1', + expert_id=e) + _load_expert(n.replace('.input_linear.', '.experts.w13_'), + w3_param, + w3_name, + shard_id='w3', + expert_id=e) + elif n.endswith('.block_sparse_moe.output_linear.weight'): + for e in range(p.size(0)): + w2_name = n.replace( + '.block_sparse_moe.output_linear.weight', + f".block_sparse_moe.experts.{e}.w2.weight") + w2_param = p[e] + _load_expert(n.replace('.output_linear.', '.experts.w2_'), + w2_param, + w2_name, + shard_id='w2', + expert_id=e) + elif n.endswith('.block_sparse_moe.router.layer.weight'): + gate_name = n.replace('.block_sparse_moe.router.layer.weight', + ".block_sparse_moe.gate.weight") + _load(gate_name, p) + else: + loaded = False + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name in n: + _load_shard(n.replace(weight_name, param_name), + p, + shard_id=shard_id) + loaded = True + if not loaded: + _load(n, p) + + return loaded_params + + +class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, + SupportsPP, IsHybrid, SupportsQuant): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + if cache_config.enable_prefix_caching: + raise RuntimeError( + "GraniteMoeHybrid currently does not support prefix caching") + + self.quant_config = vllm_config.quant_config + self.config = config + self.scheduler_config = scheduler_config + self.model = GraniteMoeHybridModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "lm_head")) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + scale=1 / + self.config.logits_scaling) + + # Used to track and store by the Mamba cache between steps. + self.mamba_cache: Optional[MambaCacheManager] = None + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs): + + mamba_cache_params = None + if not envs.VLLM_USE_V1: + if self.mamba_cache is None: + num_mamba_layers = ( + self.model_config.get_num_layers_by_block_type( + self.vllm_config.parallel_config, + LayerBlockType.mamba)) + self.mamba_cache = MambaCacheManager( + self.vllm_config, self.model_config.dtype, + num_mamba_layers, *self._get_mamba_cache_shape()) + + mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) + + hidden_states = self.model(input_ids, positions, mamba_cache_params, + intermediate_tensors, inputs_embeds) + + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> tuple[tuple[int, int], tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + + conv_state_shape, temporal_state_shape = None, None + + intermediate_size = self.config.mamba_expand * hidden_size + + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards( + self.config.mamba_n_groups, world_size)) + + # - heads and n_groups are TP-ed + conv_dim = (intermediate_size + + 2 * n_groups * self.config.mamba_d_state) + conv_state_shape = ( + divide(conv_dim, world_size), + self.config.mamba_d_conv - 1, + ) + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, d_head, d_state) = (128, 64, 128) + temporal_state_shape = ( + divide(self.config.mamba_n_heads, world_size), + self.config.mamba_d_head, + self.config.mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py new file mode 100644 index 0000000..bb160db --- /dev/null +++ b/vllm/model_executor/models/granitemoeshared.py @@ -0,0 +1,341 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only GraniteMoeShared model. + +The architecture is the same as granitemoe but with the addition of shared +experts. +""" +from collections.abc import Iterable +from typing import Optional + +import torch +from torch import nn +from transformers.models.granitemoeshared import GraniteMoeSharedConfig + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from . import mixtral +from .granitemoe import GraniteMoeAttention, GraniteMoeMoE +from .interfaces import SupportsLoRA, SupportsPP +from .utils import AutoWeightsLoader, make_layers, maybe_prefix + + +class GraniteMoeSharedMLP(nn.Module): + + def __init__( + self, + config: GraniteMoeSharedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.input_size = config.hidden_size + self.hidden_size = config.shared_intermediate_size + self.input_linear = MergedColumnParallelLinear( + input_size=self.input_size, + output_sizes=[self.hidden_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.input_linear") + self.output_linear = RowParallelLinear( + self.hidden_size, + self.input_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.output_linear") + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.input_linear(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states, _ = self.output_linear(hidden_states) + return hidden_states + + +class GraniteMoeSharedDecoderLayer(nn.Module): + + def __init__( + self, + config: GraniteMoeSharedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = GraniteMoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + attention_multiplier=config.attention_multiplier) + self.block_sparse_moe = GraniteMoeMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.block_sparse_moe") + self.shared_mlp = None if \ + getattr(config, 'shared_intermediate_size', 0) == 0 \ + else GraniteMoeSharedMLP( + config, + quant_config=quant_config, + prefix=f"{prefix}.shared_mlp" + ) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + self.residual_multiplier = config.residual_multiplier + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + if self.shared_mlp is None: + hidden_states = self.block_sparse_moe(hidden_states) + else: + # create a copy since block_sparse_moe modifies in-place + moe_hidden_states = hidden_states.clone() + moe_hidden_states = self.block_sparse_moe(moe_hidden_states) + hidden_states = moe_hidden_states + self.shared_mlp(hidden_states) + del moe_hidden_states + hidden_states = residual + hidden_states * self.residual_multiplier + + return hidden_states + + +@support_torch_compile +class GraniteMoeSharedModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.quant_config = quant_config # Required by MixtralModel + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + self.embedding_multiplier = config.embedding_multiplier + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: GraniteMoeSharedDecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers") + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + hidden_states *= self.embedding_multiplier + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states = layer(positions, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states = self.norm(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + new_weights = {} + for n, p in weights: + if n.endswith('.block_sparse_moe.input_linear.weight'): + for e in range(p.size(0)): + w1_name = n.replace( + '.block_sparse_moe.input_linear.weight', + f".block_sparse_moe.experts.{e}.w1.weight") + w3_name = n.replace( + '.block_sparse_moe.input_linear.weight', + f".block_sparse_moe.experts.{e}.w3.weight") + w1_param, w3_param = p[e].chunk(2, dim=0) + assert w1_name not in new_weights + assert w3_name not in new_weights + new_weights[w1_name] = w1_param + new_weights[w3_name] = w3_param + elif n.endswith('.block_sparse_moe.output_linear.weight'): + for e in range(p.size(0)): + w2_name = n.replace( + '.block_sparse_moe.output_linear.weight', + f".block_sparse_moe.experts.{e}.w2.weight") + w2_param = p[e] + assert w2_name not in new_weights + new_weights[w2_name] = w2_param + elif n.endswith('.block_sparse_moe.router.layer.weight'): + gate_name = n.replace('.block_sparse_moe.router.layer.weight', + ".block_sparse_moe.gate.weight") + assert gate_name not in new_weights + new_weights[gate_name] = p + else: + new_weights[n] = p + return mixtral.MixtralModel.load_weights(self, new_weights.items()) + + +class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.model = GraniteMoeSharedModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head")) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + scale=1 / + self.config.logits_scaling) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py new file mode 100644 index 0000000..4273afb --- /dev/null +++ b/vllm/model_executor/models/gritlm.py @@ -0,0 +1,224 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from array import array +from typing import Optional + +import torch +import torch.nn as nn + +from vllm.config import ModelConfig, VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.pooler import PoolerHead +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.model_executor.pooling_metadata import (PoolingMetadata, + PoolingTensors) +from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config + +from .interfaces import SupportsV0Only + +logger = init_logger(__name__) + + +class GritLMPooler(nn.Module): + + def __init__(self, model_config: ModelConfig): + super().__init__() + + self.model_config = model_config + + tokenizer = cached_tokenizer_from_config(self.model_config) + + # Collect the tokens needed for pattern matching. + # "▁<" is different from "_<". The former uses "▁" to indicate that + # the next token is the start of a word. + # "<0x0A>" is the newline token (i.e. "\n")." + self.token_ids = { + tok: tokenizer.convert_tokens_to_ids([tok])[0] + for tok in ["", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"] + } + + def tokens_to_ids(tokens: list[str]) -> array: + return array("i", [self.token_ids[token] for token in tokens]) + + self.user_pattern_ids = tokens_to_ids( + ["▁<", "|", "user", "|", ">", "<0x0A>"]) + self.embed_newline_pattern_ids = tokens_to_ids( + ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"]) + self.embed_pattern_ids = tokens_to_ids( + ["▁<", "|", "embed", "|", ">", "<0x0A>"]) + + self.head = PoolerHead(normalize=True, softmax=False) + + def _find_array(self, arr: array, target: array, start_idx: int) -> int: + """ + Find the first occurrence of target in arr starting from start_idx. + + Args: + arr: The array to search within + target: The consecutive subsequence to find + start_idx: The starting index to search from + + Returns: + int: The index of the first occurrence of target in arr. + """ + if start_idx < 0: + raise ValueError("start_idx must be non-negative") + if not target or not arr: + raise ValueError("Empty arr or target not allowed") + + target_len = len(target) + for i in range(start_idx, len(arr) - target_len + 1): + if arr[i:i + target_len] == target: + return i + return -1 + + def _get_instruction_len(self, prompt_token_ids: array) -> int: + """ + Get the length of the instruction in the prompt. + + We do a pattern matching to find the instruction in the prompt, + and then return the length of the instruction. + + The pattern matching is done using integers instead of strings + because the prompt is given as a list of token IDs. + """ + + instruction_len = 0 + + # Return no instruction in case of missing BOS token. + if prompt_token_ids[0] != self.token_ids[""]: + logger.warning("BOS token not found in prompt, " + "thus using empty string for instruction. " + "GritLM requires BOS token in prompt.") + return instruction_len + + # If user pattern is found in the prompt, that means there should be + # a newline token before the embed pattern. + embed_pattern_ids = self.embed_pattern_ids + if self._find_array(prompt_token_ids, + self.user_pattern_ids, + start_idx=1) == 1: + embed_pattern_ids = self.embed_newline_pattern_ids + + # Find the embed pattern in the prompt. + found_embed_pattern_idx = self._find_array(prompt_token_ids, + embed_pattern_ids, + start_idx=1) + + if found_embed_pattern_idx != -1: + instruction_len = found_embed_pattern_idx + len(embed_pattern_ids) + else: + logger.warning("Query instruction not found in prompt, " + "thus using BOS token as instruction instead. " + "GritLM requires query instruction in prompt.") + instruction_len = 1 + + return instruction_len + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + """ + Pool the hidden states by summing the embeddings of + non-instruction tokens. + """ + prompts_token_ids = [ + token_ids.prompt_token_ids_array + for _, token_ids in pooling_metadata.seq_data.items() + ] + + instruction_lens = torch.tensor( + [ + self._get_instruction_len(prompt_token_ids) + for prompt_token_ids in prompts_token_ids + ], + device=hidden_states.device, + ) + + prompt_lens = PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + mask = torch.zeros_like(hidden_states, dtype=torch.bool) + + start_idx = 0 + for prompt_len, instruction_len in zip(prompt_lens, instruction_lens): + end_idx = start_idx + prompt_len + mask[start_idx + instruction_len:end_idx] = True + start_idx = end_idx + + masked_hidden_states = hidden_states.masked_fill(~mask, 0.0) + + sum_embeddings = torch.zeros(len(prompt_lens), + hidden_states.size(1), + device=hidden_states.device) + + start_idx = 0 + for i, prompt_len in enumerate(prompt_lens): + end_idx = start_idx + prompt_len + sum_embeddings[i] = masked_hidden_states[start_idx:end_idx].sum( + dim=0) + start_idx = end_idx + + num_non_instruction_tokens = prompt_lens - instruction_lens + mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze( + 1) + + pooled_data = self.head(mean_embeddings, + pooling_metadata=pooling_metadata) + + pooled_outputs = [ + PoolingSequenceGroupOutput(data) for data in pooled_data + ] + + return PoolerOutput(outputs=pooled_outputs) + + +class GritLM(LlamaForCausalLM, SupportsV0Only): + """This class implements the embedding model for parasail-ai/GritLM-7B-vllm. + + The class inherits from LlamaForCausalLM and provides a custom pooling + layer. + + The main difference between the pooling layer in GritLM and the one in + LlamaForCausalLM is that GritLM ignores the query instruction in the prompt + when pooling the hidden states. + + Embedding prompts should be in the following format: + - With instruction: "<|user|>\nINSTRUCTION\n<|embed|>\nPROMPT". + - Without instruction: "<|embed|>\nPROMPT". + + Generation prompts should be in the following format: + - "<|user|>\nPROMPT\n<|assistant|>\n" + """ + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + **kwargs, + ) -> None: + # Use full attention for pooling + if vllm_config.model_config.runner_type == "pooling": + hf_config = vllm_config.model_config.hf_config + hf_config.is_causal = False + + vllm_config.cache_config.sliding_window = None + + for attr in ("sliding_window", "interleaved_sliding_window"): + if hasattr(hf_config, attr): + delattr(hf_config, attr) + + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + + self._pooler = GritLMPooler(vllm_config.model_config) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py new file mode 100644 index 0000000..2d93052 --- /dev/null +++ b/vllm/model_executor/models/grok1.py @@ -0,0 +1,546 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from +# https://github.com/ROCm/vllm/blob/cea7419f151cc50293a05b7fac8547f8f887c9f6/vllm/model_executor/models/grok1.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Grok1 model.""" +from collections.abc import Iterable +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +# Default Grok1-specific constants, overridden by config values if present +DEFAULT_ATTN_OUTPUT_MULTIPLIER = 0.08838834764831845 +DEFAULT_OUTPUT_MULTIPLIER_SCALE = 0.5773502691896257 +DEFAULT_EMBEDDING_MULTIPLIER_SCALE = 78.38367176906169 + + +class Grok1MoE(nn.Module): + """A tensor-parallel MoE implementation for Grok1 that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__(self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear(hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate") + + self.experts = FusedMoE(num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + activation="gelu", + prefix=f"{prefix}.experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + router_logits = 30.0 * F.tanh(router_logits / 30.0) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class Grok1Attention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + config=None, # Added config parameter + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.config = config # Store config reference + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + + attn_logits_soft_cap = max( + getattr(config, "attn_logit_softcapping", 30.0), 0.0) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap, + prefix=f"{prefix}.attn") + self.attn_multiplier = getattr(self.config, "attn_output_multiplier", + 1.0) if self.config else 1.0 + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + output *= self.attn_multiplier + return output + + +class Grok1DecoderLayer(nn.Module): + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Check for fp8 quantization + self.use_fp8 = False + if quant_config is not None: + self.use_fp8 = getattr(quant_config, "is_fp8_w8a8", + lambda: False)() + if not self.use_fp8 and hasattr(quant_config, "is_fp8"): + self.use_fp8 = quant_config.is_fp8 + + # Requires transformers > 4.32.0 + # Default rope_theta value if not in config + rope_theta = 10000 + self.attn = Grok1Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + config=config) # Pass config to Grok1Attention + + # Grok1 uses "num_experts" in its config + num_experts = getattr(config, "num_experts", 8) + num_experts_per_tok = getattr(config, "num_experts_per_tok", 2) + + self.moe_block = Grok1MoE(num_experts=num_experts, + top_k=num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.moe_block") + + self.pre_attn_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attn_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_moe_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_moe_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.pre_attn_norm(hidden_states) + else: + hidden_states, residual = self.pre_attn_norm( + hidden_states, residual) + + hidden_states = self.attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Post attention normalization + hidden_states = self.post_attn_norm(hidden_states) + + # MoE block with normalization + hidden_states, residual = self.pre_moe_norm(hidden_states, residual) + hidden_states = self.moe_block(hidden_states) + hidden_states = self.post_moe_norm(hidden_states) + + return hidden_states, residual + + +@support_torch_compile +class Grok1Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.quant_config = quant_config + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + self.embedding_multiplier_scale = getattr( + config, "embedding_multiplier_scale", + DEFAULT_EMBEDDING_MULTIPLIER_SCALE) + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Grok1DecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers") + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + hidden_states = hidden_states * self.embedding_multiplier_scale + return hidden_states + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + # Map Grok1's unique expert parameter names to standard names + # Grok1 uses "num_experts" in its config + num_experts = getattr(self.config, "num_experts", 8) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="linear", # Grok1 specific + ckpt_down_proj_name="linear_1", # Grok1 specific + ckpt_up_proj_name="linear_v", # Grok1 specific + num_experts=num_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name.endswith("scale"): + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + # Handle Grok1-specific norm.scale naming + if "norm.scale" in name: + name = name.replace("scale", "weight") + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + + self.model = Grok1Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + self.output_multiplier_scale = getattr( + config, "output_multiplier_scale", DEFAULT_OUTPUT_MULTIPLIER_SCALE) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + self.output_multiplier_scale) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + # Skip lm_head when tie_word_embeddings is True + skip_prefixes = (["lm_head"] + if self.config.tie_word_embeddings else None) + + loader = AutoWeightsLoader( + self, + skip_prefixes=skip_prefixes, + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py new file mode 100644 index 0000000..467b074 --- /dev/null +++ b/vllm/model_executor/models/h2ovl.py @@ -0,0 +1,549 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# adapted from https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/modeling_h2ovl_chat.py +# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/image_process.py +# -------------------------------------------------------- +# H2OVL-Mississippi +# Copyright (c) 2024 H2O.AI +# Licensed under Apache 2.0 License [see LICENSE for details] +# -------------------------------------------------------- +from collections.abc import Mapping, Sequence +from typing import Optional, Union + +import torch +from PIL import Image +from transformers import PretrainedConfig + +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalKwargs +from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, + MultiModalDataItems) +from vllm.multimodal.processing import (MultiModalHashes, PromptReplacement, + PromptUpdate, PromptUpdateDetails) +from vllm.transformers_utils.tokenizer import AnyTokenizer + +from .intern_vit import InternVisionModel +from .internvl import (IMG_CONTEXT, IMG_END, IMG_START, + BaseInternVLDummyInputsBuilder, + BaseInternVLMultiModalProcessor, + BaseInternVLProcessingInfo, BaseInternVLProcessor, + InternVLChatModel, build_transform, + find_closest_aspect_ratio, get_internvl_target_ratios) + + +def resolve_h2ovl_min_max_num( + *, + min_dynamic_patch: int, + max_dynamic_patch: int, + dynamic_image_size: bool, + use_thumbnail: bool, +) -> tuple[int, int]: + min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1 + max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 + + if use_thumbnail and max_dynamic_patch != 1: + max_dynamic_patch += 1 + + return min_dynamic_patch, max_dynamic_patch + + +def get_h2ovl_target_ratios( + min_num: int, + max_num: int, + *, + prior_aspect_ratio: Optional[tuple[int, int]], +) -> list[tuple[int, int]]: + target_ratios = get_internvl_target_ratios(min_num, max_num) + + # if prior_aspect_ratio is provided, filter the target ratios + if prior_aspect_ratio is not None: + target_ratios = [ + ratio for ratio in target_ratios if prior_aspect_ratio[0] % + ratio[0] != 0 and prior_aspect_ratio[1] % ratio[1] != 0 + ] + + return target_ratios + + +# modified to include blocks generated in second pass +def calculate_h2ovl_targets( + *, + orig_width: int, + orig_height: int, + target_ratios: list[tuple[int, int]], + image_size: int, + use_thumbnail: bool, +) -> tuple[int, int, int, tuple[int, int]]: + aspect_ratio = orig_width / orig_height + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, + target_ratios, + width=orig_width, + height=orig_height, + image_size=image_size, + ) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # add thumbnail image if num_blocks != 1 + if use_thumbnail and blocks != 1: + blocks += 1 + + return blocks, target_width, target_height, target_aspect_ratio + + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B +# refactored to handle prior_aspect_ratio +def dynamic_preprocess_h2ovl( + image: Image.Image, + *, + target_ratios: list[tuple[int, int]], + image_size: int, + use_thumbnail: bool, +) -> tuple[list[Image.Image], tuple[int, int]]: + orig_width, orig_height = image.size + + # calculate the number of blocks without thumbnail + ( + blocks, + target_width, + target_height, + target_aspect_ratio, + ) = calculate_h2ovl_targets( + orig_width=orig_width, + orig_height=orig_height, + target_ratios=target_ratios, + image_size=image_size, + use_thumbnail=False, + ) + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + + assert len(processed_images) == blocks + + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + + return processed_images, target_aspect_ratio + + +def _preprocess_image( + image: Image.Image, + *, + input_size: int, + min_num: int, + max_num: int, + use_thumbnail: bool, + prior_aspect_ratio: Optional[tuple[int, int]], +) -> tuple[torch.Tensor, tuple[int, int]]: + target_ratios = get_h2ovl_target_ratios( + min_num, + max_num, + prior_aspect_ratio=prior_aspect_ratio, + ) + + transform = build_transform(input_size=input_size) + images, target_aspect_ratio = dynamic_preprocess_h2ovl( + image, + image_size=input_size, + use_thumbnail=use_thumbnail, + target_ratios=target_ratios, + ) + + pixel_values = torch.stack([transform(image) for image in images]) + return pixel_values, target_aspect_ratio + + +# refactored to use the _preprocess_image function +def image_to_pixel_values_h2ovl( + image: Image.Image, + *, + input_size: int, + min_num: int, + max_num: int, + use_thumbnail: bool, + use_msac: bool, +) -> torch.Tensor: + # when MSAC is turned on, we need to process the image twice + if use_msac: + # first pass + pixel_values1, aspect_ratio1 = _preprocess_image( + image, + input_size=input_size, + min_num=1, + max_num=max_num, + use_thumbnail=True, + prior_aspect_ratio=None, + ) + # second pass + pixel_values2, _ = _preprocess_image( + image, + input_size=input_size, + min_num=3, + max_num=max_num, + use_thumbnail=True, + prior_aspect_ratio=aspect_ratio1, + ) + # combine pixel values + pixel_values = torch.cat( + [pixel_values2[:-1], pixel_values1[:-1], pixel_values2[-1:]], 0) + + else: + pixel_values, _ = _preprocess_image( + image, + input_size=input_size, + min_num=min_num, + max_num=max_num, + use_thumbnail=use_thumbnail, + prior_aspect_ratio=None, + ) + + return pixel_values + + +class H2OVLProcessor(BaseInternVLProcessor): + + def __init__( + self, + config: PretrainedConfig, + tokenizer: AnyTokenizer, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + use_msac: Optional[bool] = None, + ) -> None: + super().__init__( + config, + tokenizer, + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) + + if use_msac is None: + use_msac = config.use_msac + assert isinstance(use_msac, bool) + + self.use_msac = use_msac + + @property + def image_token_id(self) -> int: + return self.tokenizer.get_vocab()[IMG_CONTEXT] + + def get_image_repl( + self, + feature_size: int, + num_patches: Optional[int], + ) -> PromptUpdateDetails[str]: + repl_features = IMG_CONTEXT * feature_size + repl_full = IMG_START + repl_features + IMG_END + + return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT) + + def resolve_min_max_num( + self, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + use_thumbnail: Optional[bool] = None, + ) -> tuple[int, int]: + min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch + is None else min_dynamic_patch) + max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch + is None else max_dynamic_patch) + dynamic_image_size = (self.dynamic_image_size if dynamic_image_size + is None else dynamic_image_size) + use_thumbnail = (self.use_thumbnail + if use_thumbnail is None else use_thumbnail) + + return resolve_h2ovl_min_max_num( + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + use_thumbnail=use_thumbnail, + ) + + def resolve_target_ratios( + self, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + use_thumbnail: Optional[bool] = None, + prior_aspect_ratio: Optional[tuple[int, int]] = None, + override_min_num: Optional[int] = None, + ) -> list[tuple[int, int]]: + min_num, max_num = self.resolve_min_max_num( + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + use_thumbnail=use_thumbnail, + ) + if override_min_num is not None: + min_num = override_min_num + + return get_h2ovl_target_ratios( + min_num, + max_num, + prior_aspect_ratio=prior_aspect_ratio, + ) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + use_msac: Optional[bool] = None, + ) -> int: + use_msac = (self.use_msac if use_msac is None else use_msac) + + use_thumbnail = self.use_thumbnail + + if use_msac: + target_ratios_1 = self.resolve_target_ratios( + use_thumbnail=False, # Applied in calculate_targets + override_min_num=1, + ) + num_patches_1, _, _, aspect_ratio_1 = calculate_h2ovl_targets( + orig_width=image_width, + orig_height=image_height, + image_size=self.image_size, + target_ratios=target_ratios_1, + use_thumbnail=True, + ) + + target_ratios_2 = self.resolve_target_ratios( + use_thumbnail=False, # Applied in calculate_targets + prior_aspect_ratio=aspect_ratio_1, + override_min_num=3, + ) + num_patches_2, _, _, _ = calculate_h2ovl_targets( + orig_width=image_width, + orig_height=image_height, + image_size=self.image_size, + target_ratios=target_ratios_2, + use_thumbnail=True, + ) + + num_patches = num_patches_1 + num_patches_2 - 1 + else: + target_ratios = self.resolve_target_ratios( + use_thumbnail=False, # Applied in calculate_targets + ) + num_patches, _, _, _ = calculate_h2ovl_targets( + orig_width=image_width, + orig_height=image_height, + image_size=self.image_size, + target_ratios=target_ratios, + use_thumbnail=use_thumbnail, + ) + + return num_patches * self.num_image_token + + def _images_to_pixel_values_lst( + self, + images: list[Image.Image], + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + ) -> list[torch.Tensor]: + use_msac = self.use_msac if len(images) == 1 else False + + min_num, max_num = self.resolve_min_max_num( + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + use_thumbnail=False, # Applied in image_to_pixel_values + ) + + return [ + image_to_pixel_values_h2ovl( + image, + input_size=self.image_size, + min_num=min_num, + max_num=max_num, + use_thumbnail=self.use_thumbnail, + use_msac=use_msac, + ) for image in images + ] + + +class H2OVLProcessingInfo(BaseInternVLProcessingInfo): + + def get_hf_processor( + self, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + **kwargs: object, + ) -> H2OVLProcessor: + if min_dynamic_patch is not None: + kwargs["min_dynamic_patch"] = min_dynamic_patch + if max_dynamic_patch is not None: + kwargs["max_dynamic_patch"] = max_dynamic_patch + if dynamic_image_size is not None: + kwargs["dynamic_image_size"] = dynamic_image_size + + return self.ctx.init_processor( + H2OVLProcessor, + config=self.get_hf_config(), + tokenizer=self.get_tokenizer(), + **kwargs, + ) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[H2OVLProcessor], + use_msac: Optional[bool] = None, + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + return processor.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + use_msac=use_msac, + ) + + +class H2OVLMultiModalProcessor( + BaseInternVLMultiModalProcessor[H2OVLProcessingInfo]): + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + if "image_num_patches" in out_mm_kwargs: + image_num_patches = out_mm_kwargs["image_num_patches"] + assert isinstance(image_num_patches, torch.Tensor) + image_num_patches = image_num_patches.tolist() + elif "image_embeds" in out_mm_kwargs: + # TODO: Use image size information in dictionary embedding inputs + # to compute num_patches (similar to Qwen2-VL) + image_num_patches = [None] * len(out_mm_kwargs["image_embeds"]) + else: + image_num_patches = [] + + num_images = len(image_num_patches) + + def get_replacement_internvl(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + feature_size = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + feature_size = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + processor=hf_processor, + use_msac=None if num_images == 1 else False, + ) + + num_patches = image_num_patches[item_idx] + if num_patches is not None: + assert isinstance(num_patches, int) + + return hf_processor.get_image_repl(feature_size, num_patches) + + return [ + PromptReplacement( + modality="image", + target="", + replacement=get_replacement_internvl, + ) + ] + + def _cached_apply_hf_processor( + self, + prompt: Union[str, list[int]], + mm_data_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + *, + return_mm_hashes: bool, + ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: + # The processor logic is different for len(images) <= 1 vs > 1 + # Since the processing cache assumes that the processor output is + # invariant of how many images are passed per prompt, we only + # perform caching for the most common case + if mm_data_items.get_count("image", strict=False) > 1: + return self._apply_hf_processor( + prompt=prompt, + mm_data_items=mm_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + return_mm_hashes=return_mm_hashes, + ) + + return super()._cached_apply_hf_processor( + prompt=prompt, + mm_data_items=mm_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + return_mm_hashes=return_mm_hashes, + ) + + +@MULTIMODAL_REGISTRY.register_processor( + H2OVLMultiModalProcessor, + info=H2OVLProcessingInfo, + dummy_inputs=BaseInternVLDummyInputsBuilder) +class H2OVLChatModel(InternVLChatModel): + + def _init_vision_model( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + *, + is_mono: bool, + prefix: str, + ): + if not is_mono: + vision_feature_layer = config.select_layer + if vision_feature_layer < 0: + num_hidden_layers = (config.vision_config.num_hidden_layers + + vision_feature_layer + 1) + else: + num_hidden_layers = vision_feature_layer + 1 + + return InternVisionModel( + config.vision_config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers, + prefix=prefix, + ) + else: + msg = "Monolith mode is not applicable to H2OVL" + raise NotImplementedError(msg) diff --git a/vllm/model_executor/models/hunyuan_v1_moe.py b/vllm/model_executor/models/hunyuan_v1_moe.py new file mode 100644 index 0000000..89ca3e8 --- /dev/null +++ b/vllm/model_executor/models/hunyuan_v1_moe.py @@ -0,0 +1,897 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# coding=utf-8 +# Copyright 2024 The HunYuan team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only HunYuan model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Any, Optional, Union + +import regex as re +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers + + +def _get_cla_factor(config: PretrainedConfig) -> int: + if not getattr(config, "use_cla", False): + return 1 + return getattr(config, "cla_share_factor", 1) + + +class HunYuanMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + reduce_results: bool = True, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + reduce_results=reduce_results, + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class HunYuanAttention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + layer_id: int = -1, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + if hasattr(config, "head_dim"): + self.head_dim = config.head_dim + elif hasattr(config, "attention_head_dim"): + self.head_dim = config.attention_head_dim + else: + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.use_qk_norm = getattr(config, "use_qk_norm", False) + self.layer_id = layer_id + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=True, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + if self.use_qk_norm: + self.query_layernorm = RMSNorm(self.head_dim, + eps=config.rms_norm_eps) + self.key_layernorm = RMSNorm(self.head_dim, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_states: Optional[tuple[torch.Tensor]] = None, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + ori_k = k + if self.use_qk_norm: + q = self.query_layernorm( + q.view(-1, self.num_heads, self.head_dim).contiguous()) + k = self.key_layernorm( + k.view(-1, self.num_kv_heads, self.head_dim).contiguous()) + + attn_output = self.attn(q, k, v) + # For o_proj + attn_output = attn_output.view(q.shape[0], -1) + output, _ = self.o_proj(attn_output) + return output, (ori_k, v) + + +class HunYuanCrossAttention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + layer_id: int = -1, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + if hasattr(config, "head_dim"): + self.head_dim = config.head_dim + elif hasattr(config, "attention_head_dim"): + self.head_dim = config.attention_head_dim + else: + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.use_qk_norm = getattr(config, "use_qk_norm", False) + self.layer_id = layer_id + + self.q_proj = ColumnParallelLinear( + hidden_size, + hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=True, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=AttentionType.ENCODER_DECODER, + ) + + if self.use_qk_norm: + self.query_layernorm = RMSNorm(self.head_dim, + eps=config.rms_norm_eps) + self.key_layernorm = RMSNorm(self.head_dim, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_states: Optional[tuple[torch.Tensor]] = None, + ) -> torch.Tensor: + assert kv_states is not None + ori_k, v = kv_states # use last layer kv, + k = ori_k + q, _ = self.q_proj(hidden_states) + k_tmp = torch.empty_like(k) # Todo: reduant rotary embedding + q, _ = self.rotary_emb(positions, q, k_tmp) + if self.use_qk_norm: + q = self.query_layernorm( + q.view(-1, self.num_heads, self.head_dim).contiguous()) + k = self.key_layernorm( + k.view(-1, self.num_kv_heads, self.head_dim).contiguous()) + + attn_output = self.attn(q, k, v) + # For o_proj + attn_output = attn_output.view(q.shape[0], -1) + output, _ = self.o_proj(attn_output) + return output, (ori_k, v) + + +class HunYuanSparseMoeBlock(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + layer_id: int = -1, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + # Get layer_id topk if config.moe_topk is a list + if isinstance(config.moe_topk, list): + assert layer_id >= 0 + assert len(config.moe_topk) > layer_id + top_k = config.moe_topk[layer_id] + else: + top_k = config.moe_topk + + # If it is moe, moe_intermediate_size is preferred + intermediate_size = config.intermediate_size + if config.moe_intermediate_size is not None: + intermediate_size = (config.moe_intermediate_size if isinstance( + config.moe_intermediate_size, int) else + config.moe_intermediate_size[layer_id]) + + self.experts = FusedMoE( + num_experts=config.num_experts, + top_k=top_k, + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + reduce_results=False, + renormalize=top_k > 1, + quant_config=quant_config, + prefix=f"{prefix}.experts", + ) + + self.gate = ReplicatedLinear(config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + if config.use_mixed_mlp_moe > 0: + # Get layer_id num_shared_expert if config.num_shared_expert is + # a list. + if isinstance(config.num_shared_expert, list): + assert layer_id >= 0 + assert len(config.num_shared_expert) > layer_id + num_shared_expert = config.num_shared_expert[layer_id] + else: + num_shared_expert = config.num_shared_expert + + self.shared_mlp = HunYuanMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size * num_shared_expert, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + ) + else: + self.shared_mlp = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + shared_output = None + if self.shared_mlp is not None: + shared_output = self.shared_mlp(hidden_states) + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(orig_shape) + + +class HunYuanDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + layer_id: int = -1, + ) -> None: + super().__init__() + assert layer_id >= 0 + self.layer_id = layer_id + self.hidden_size = config.hidden_size + self.intermediate_size = (config.intermediate_size if isinstance( + config.intermediate_size, int) else + config.intermediate_size[layer_id]) + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + cla_factor = _get_cla_factor(config) + attention_type = (AttentionType.ENCODER_DECODER + if layer_id >= 0 and layer_id % cla_factor != 0 else + AttentionType.DECODER) + if attention_type == AttentionType.DECODER: + self.self_attn = HunYuanAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + layer_id=layer_id, + ) + elif attention_type == AttentionType.ENCODER_DECODER: + self.self_attn = HunYuanCrossAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + layer_id=layer_id, + ) + else: + raise RuntimeError(f"Unsupported attention type: {attention_type}") + + self.mlp = HunYuanSparseMoeBlock( + config=config, + quant_config=quant_config, + layer_id=layer_id, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + kv_states: Optional[tuple[torch.Tensor]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states, ori_kv_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_states=kv_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual, ori_kv_states + + +@support_torch_compile +class HunYuanModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.quant_config = quant_config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: HunYuanDecoderLayer( + config=config, + layer_id=int(prefix.split(".")[-1]), + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + cla_factor = _get_cla_factor(self.config) + prev_kv_states = None + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual, kv_states = layer( + positions, + hidden_states, + residual, + prev_kv_states, + ) + + if (getattr(self.config, "use_cla", False) + and (i - self.start_layer) % cla_factor == 0): + prev_kv_states = kv_states + else: + prev_kv_states = None + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class HunYuanMoEV1ForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + self.quant_config = quant_config + self.lora_config = lora_config + + self.model = HunYuanModel(vllm_config=vllm_config, prefix="model") + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = get_sampler() + else: + self.lm_head = PPMissingLayer() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def _split_qkv_weight(self, qkv: torch.Tensor): + num_attention_heads = self.config.num_attention_heads + num_kv_heads = getattr(self.config, "num_key_value_heads", + self.config.num_attention_heads) + num_key_value_groups = num_attention_heads // num_kv_heads + hidden_size = self.config.hidden_size + + if hasattr(self.config, "head_dim"): + attention_head_dim = self.config.head_dim + elif hasattr(self.config, "attention_head_dim"): + attention_head_dim = self.config.attention_head_dim + else: + attention_head_dim = self.config.hidden_size // num_attention_heads + + qkv = qkv.reshape(num_kv_heads, num_key_value_groups + 2, + attention_head_dim, hidden_size) + q, k, v = torch.split(qkv, (num_key_value_groups, 1, 1), dim=1) + q = q.reshape(-1, hidden_size) + k = k.reshape(-1, hidden_size) + v = v.reshape(-1, hidden_size) + return torch.concat((q, k, v)) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + cla_factor = _get_cla_factor(self.config) + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + + num_attention_heads = self.config.num_attention_heads + num_kv_heads = getattr(self.config, "num_key_value_heads", + self.config.num_attention_heads) + split_params_mapping = [ + (".gate_up_proj", ".gate_and_up_proj", 2, [(1, 1), (0, 1)], None), + ( + ".qkv_proj", + ".qkv_proj", + num_attention_heads + num_kv_heads * 2, + [("q", num_attention_heads), ("k", num_kv_heads), + ("v", num_kv_heads)], + self._split_qkv_weight, + ), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + ) + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if "gate_proj_bias" in name: + name = name.replace("gate_proj_bias", "gate_proj.bias") + if "up_proj_bias" in name: + name = name.replace("up_proj_bias", "up_proj.bias") + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name)): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + continue + + is_found = False + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if "mlp.experts" in name: + continue + # cross layer only have q_proj, skip qkv pack + if weight_name == ".q_proj": + match = re.search(r"layers\.\d+", name) + if match: + layer_id = int(match.group(0).split(".")[-1]) + if cla_factor > 1 and layer_id % cla_factor != 0: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + is_found = True + break + if is_found: + continue + + for ( + param_name, + weight_name, + den, + split_param, + func, + ) in split_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + assert loaded_weight.shape[0] % den == 0 + units = loaded_weight.shape[0] // den + + param = params_dict[name] + weight_loader = param.weight_loader + offset = 0 + for shard_id, num in split_param: + new_offset = offset + num * units + if func: + weight_loader(param, + func(loaded_weight)[offset:new_offset], + shard_id) + else: + weight_loader(param, loaded_weight[offset:new_offset], + shard_id) + offset = new_offset + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + if "mlp.gate.wg." in name: + name = name.replace("wg.", "") + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py new file mode 100644 index 0000000..4720fd0 --- /dev/null +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -0,0 +1,388 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# adapted from https://github.com/huggingface/transformers/blob/v4.43.2/src/transformers/models/idefics2/modeling_idefics2.py +# Copyright 2024 The vLLM team. +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Idefics2 model.""" + +from collections.abc import Iterable +from typing import Optional + +import torch +from torch import nn +from transformers.models.idefics2.configuration_idefics2 import ( + Idefics2Config, Idefics2VisionConfig) +from vllm.attention.layer import MultiHeadAttention +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + + +class Idefics2VisionEmbeddings(nn.Module): + """ + This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings + ` to enable images of variable + resolution. + + The modifications are adapted from [Patch n' Pack: NaViT, a Vision + Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304) + which allows treating images in their native aspect ratio and without the + need to resize them to the same fixed size. In particular, we start from the + original pre-trained SigLIP model(which uses images of fixed-size square + images) and adapt it by training on images of variable resolutions. + """ + + def __init__(self, config: Idefics2VisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, + self.embed_dim) + + def forward(self, + pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor, + tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor: + batch_size, _, max_im_h, max_im_w = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(target_dtype)) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + max_nb_patches_h, max_nb_patches_w = ( + max_im_h // self.patch_size, + max_im_w // self.patch_size, + ) + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, + 1 / self.num_patches_per_side) + position_ids = torch.full(size=(batch_size, + max_nb_patches_h * max_nb_patches_w), + fill_value=0) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + + if tgt_sizes is not None: + nb_patches_h = tgt_sizes[batch_idx][0] + nb_patches_w = tgt_sizes[batch_idx][1] + else: + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + bucket_coords_h = torch.bucketize(fractional_coords_h, + boundaries, + right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, + boundaries, + right=True) + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + position_ids = position_ids.to(self.position_embedding.weight.device) + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class Idefics2VisionAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Idefics2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" # noqa: E501 + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + self.attn = MultiHeadAttention(self.num_heads_per_partition, + self.head_dim, self.scale) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj( + hidden_states + ) # batch_size, q_len, 3 * num_heads_per_partition * head_dim + query_states, key_states, value_states = qkv.chunk(3, dim=-1) + out = self.attn(query_states, key_states, value_states) + attn_output, _ = self.out_proj(out) + return attn_output + + +class Idefics2VisionMLP(nn.Module): + + def __init__( + self, + config: Idefics2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class Idefics2EncoderLayer(nn.Module): + + def __init__( + self, + config: Idefics2Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Idefics2VisionAttention(config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn") + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = Idefics2VisionMLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + + """ + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn(hidden_states) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Idefics2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention + layers. Each layer is a + [`Idefics2EncoderLayer`]. + + Args: + config: Idefics2Config + """ + + def __init__( + self, + config: Idefics2Config, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + + self.layers = nn.ModuleList([ + Idefics2EncoderLayer(config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(num_hidden_layers) + ]) + + def forward( + self, + inputs_embeds: torch.Tensor, + ) -> torch.Tensor: + r""" + Args: + inputs_embeds (torch.Tensor): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. + This is useful if you want more control over how to convert + `input_ids` indices into associated vectorsthan the model's + internal embedding lookup matrix. + """ + hidden_states = inputs_embeds + for encoder_layer in self.layers: + layer_outputs = encoder_layer(hidden_states) + hidden_states = layer_outputs + return hidden_states + + +class Idefics2VisionTransformer(nn.Module): + + def __init__( + self, + config: Idefics2VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + require_post_norm: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + + embed_dim = config.hidden_size + self.config = config + self.embeddings = Idefics2VisionEmbeddings(config) + self.encoder = Idefics2Encoder( + config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + prefix=f"{prefix}.encoder") + + num_hidden_layers = config.num_hidden_layers + if len(self.encoder.layers) > config.num_hidden_layers: + raise ValueError( + f"The original encoder only has {num_hidden_layers} " + f"layers, but you requested {len(self.encoder.layers)} layers." + ) + + self.require_post_norm = require_post_norm + self.post_layernorm = nn.LayerNorm( + embed_dim, + eps=config.layer_norm_eps, + ) if require_post_norm else nn.Identity() + + def get_input_embeddings(self): + return self.embeddings + + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + tgt_sizes: Optional[torch.IntTensor] = None, + ) -> torch.Tensor: + hidden_states = self.embeddings( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + tgt_sizes=tgt_sizes, + ) + encoder_outputs = self.encoder(hidden_states) + last_hidden_state = self.post_layernorm(encoder_outputs) + return last_hidden_state + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + layer_count = len(self.encoder.layers) + + for name, loaded_weight in weights: + # skip pooling header + if name.startswith("head."): + continue + + # post_layernorm is optional + if (name.startswith("post_layernorm.") + and not self.require_post_norm): + continue + + # omit layers when num_hidden_layers_override is set + if name.startswith("encoder.layers."): + layer_idx = int(name.split(".")[2]) + if layer_idx >= layer_count: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py new file mode 100644 index 0000000..4643468 --- /dev/null +++ b/vllm/model_executor/models/idefics3.py @@ -0,0 +1,786 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Idefics3 model compatible with HuggingFace weights.""" + +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Literal, Optional, TypedDict, Union + +import torch +from torch import nn +from transformers import (AddedToken, BatchFeature, Idefics3Config, + Idefics3ImageProcessor, Idefics3Processor) + +from vllm.config import VllmConfig +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize +# yapf conflicts with isort for this block +# yapf: disable +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalDataItems, PromptReplacement, + PromptUpdate, PromptUpdateDetails) +# yapf: enable +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors + +# yapf: disable +from .idefics2_vision_model import ( + Idefics2VisionTransformer as Idefics3VisionTransformer) +# yapf: enable +from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal +from .llama import LlamaModel +from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, + merge_multimodal_embeddings) + + +class Idefics3ImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: torch.Tensor + """ + Shape: `(batch_size * num_images * num_patches, + num_channels, height, width)` + """ + pixel_attention_mask: torch.Tensor + + num_patches: torch.Tensor + """Shape: `(batch_size * num_images)`""" + + +class Idefics3ImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """ + Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + `hidden_size` must match the hidden size of language model backbone. + """ + + +ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs] + + +class Idefics3ProcessingInfo(BaseProcessingInfo): + + def get_hf_processor( + self, + *, + size: Optional[dict[str, int]] = None, + **kwargs: object, + ) -> Idefics3Processor: + if size is not None: + kwargs["size"] = size + + return self.ctx.get_hf_processor(Idefics3Processor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def _resize_output_size(self, + *, + height: int, + width: int, + max_len: Optional[int] = None, + min_len: int = 1, + max_size: Optional[int] = None) -> tuple[int, int]: + # Set default value for max_len if not provided + max_len = max(height, width) if max_len is None else max_len + aspect_ratio = width / height + + # Handle the maximum size constraint + if max_size is not None: + max_len = min(max_len, max_size) + + # Adjust dimensions according to the aspect ratio + if width >= height: + width = max_len + height = int(width / aspect_ratio) + else: + height = max_len + width = int(height * aspect_ratio) + + # Ensure both width and height are even (if needed) + height += height % 2 + width += width % 2 + + # Ensure dimensions are not smaller than the minimum length + height = max(height, min_len) + width = max(width, min_len) + + return height, width + + def _get_resize_output_image_size( + self, + *, + image_width: int, + image_height: int, + resolution_max_side: int, + ) -> tuple[int, int]: + hf_processor = self.get_hf_processor() + image_processor: Idefics3ImageProcessor = hf_processor.image_processor + max_image_size = image_processor.size['longest_edge'] + if resolution_max_side > max_image_size: + raise ValueError( + "`resolution_max_side` cannot be larger than `max_image_size`") + + height, width = image_height, image_width + + # Find the output size, when rescaling the longest edge to max_len and + # preserving the aspect ratio + height, width = self._resize_output_size(height=height, + width=width, + max_len=resolution_max_side) + return height, width + + def _get_image_feature_grid_size( + self, + *, + image_width: int, + image_height: int, + processor: Optional[Idefics3Processor], + ) -> tuple[int, int]: + if processor is None: + processor = self.get_hf_processor() + + image_processor: Idefics3ImageProcessor = processor.image_processor + + max_image_size = image_processor.max_image_size['longest_edge'] + size = image_processor.size['longest_edge'] + assert size % max_image_size == 0, ( + "`longest_edge` in image_processor's `size` must be divisible by " + "`longest_edge` in `max_image_size`, this may be caused by " + "incorrect mm_kwargs override.") + + resized_height, resized_width = self._get_resize_output_image_size( + image_width=image_width, + image_height=image_height, + resolution_max_side=size, + ) + if resized_height > max_image_size or resized_width > max_image_size: + grid_h = math.ceil(resized_height / max_image_size) + grid_w = math.ceil(resized_width / max_image_size) + else: + grid_h = grid_w = 0 + return grid_w, grid_h + + def get_num_patches( + self, + *, + image_width: int, + image_height: int, + processor: Optional[Idefics3Processor], + ) -> int: + grid_w, grid_h = self._get_image_feature_grid_size( + image_width=image_width, + image_height=image_height, + processor=processor, + ) + + return grid_w * grid_h + 1 + + # TODO: Remove after requiring transformers>=4.52 + def _get_content(self, token: Union[AddedToken, str]) -> str: + if isinstance(token, str): + return token + + return token.content + + def _get_image_token( + self, + processor: Optional[Idefics3Processor]) -> tuple[str, str, str]: + if processor is None: + processor = self.get_hf_processor() + + image_token = self._get_content(processor.image_token) + fake_image_token = self._get_content(processor.fake_image_token) + global_image_token = processor.global_image_tag + return image_token, fake_image_token, global_image_token + + def get_image_repl( + self, + *, + image_width: int, + image_height: int, + processor: Optional[Idefics3Processor], + ) -> str: + if processor is None: + processor = self.get_hf_processor() + + image_token, fake_image_token, global_img_token = self._get_image_token( + processor) + image_seq_len = processor.image_seq_len + grid_placeholder = "" + + p_img = image_token * image_seq_len + global_img_placeholder = fake_image_token + global_img_token + p_img + tile_img_placeholder = fake_image_token + grid_placeholder + p_img + + grid_w, grid_h = self._get_image_feature_grid_size( + image_width=image_width, + image_height=image_height, + processor=processor, + ) + if grid_w == 0 and grid_h == 0: + return global_img_placeholder + fake_image_token + + tiles_placeholder = list[str]() + for i in range(grid_h): + for j in range(grid_w): + placeholder_per_tile = tile_img_placeholder.format(n_h=i + 1, + n_w=j + 1) + tiles_placeholder.append(placeholder_per_tile) + # Add line break if it is the last tile in the row + if j == grid_w - 1: + tiles_placeholder.append("\n") + + return "".join([ + *tiles_placeholder, + "\n", + global_img_placeholder, + fake_image_token, + ]) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + processor: Optional[Idefics3Processor], + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + num_patches = self.get_num_patches( + image_width=image_width, + image_height=image_height, + processor=processor, + ) + + return num_patches * processor.image_seq_len + + def get_image_size_with_most_features(self) -> ImageSize: + processor = self.get_hf_processor() + image_processor: Idefics3ImageProcessor = processor.image_processor + + return ImageSize( + width=image_processor.size["longest_edge"], + height=image_processor.size["longest_edge"], + ) + + +class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo] + ): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token, _, _ = self.info._get_image_token(processor) + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + hf_processor = self.info.get_hf_processor() + image_processor: Idefics3ImageProcessor = hf_processor.image_processor + longest_edge = image_processor.max_image_size['longest_edge'] + + return { + "image": + self._get_dummy_images(width=longest_edge, + height=longest_edge, + num_images=num_images) + } + + +class Idefics3MultiModalProcessor( + BaseMultiModalProcessor[Idefics3ProcessingInfo]): + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + # Text-only input not supported in composite processor + if not (images := mm_data.get("images", [])): + prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + processed_outputs = super()._call_hf_processor( + prompt, + mm_data, + mm_kwargs, + tok_kwargs, + ) + + parsed_images = (self._get_data_parser().parse_mm_data({ + "image": images + }).get_items("image", ImageProcessorItems)) + image_sizes = [ + parsed_images.get_image_size(i) for i in range(len(parsed_images)) + ] + hf_processor = self.info.get_hf_processor(**mm_kwargs) + + num_patches = [ + self.info.get_num_patches( + image_width=size.width, + image_height=size.height, + processor=hf_processor, + ) for size in image_sizes + ] + processed_outputs["num_patches"] = torch.tensor(num_patches) + + # Remove the extra batch dimension + processed_outputs["pixel_values"].squeeze_(0) + processed_outputs["pixel_attention_mask"].squeeze_(0) + + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + num_patches = hf_inputs.get("num_patches", torch.empty(0)) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", num_patches), + pixel_attention_mask=MultiModalFieldConfig.flat_from_sizes( + "image", num_patches), + image_embeds=MultiModalFieldConfig.batched("image"), + num_patches=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_token, _, _ = self.info._get_image_token(hf_processor) + + def get_replacement_idefics3(item_idx: int) -> PromptUpdateDetails: + images = mm_items.get_items("image", ImageProcessorItems) + + image_size = images.get_image_size(item_idx) + + image_repl = self.info.get_image_repl( + image_width=image_size.width, + image_height=image_size.height, + processor=hf_processor, + ) + + return PromptUpdateDetails.select_text( + image_repl, + embed_text=image_token, + ) + + return [ + PromptReplacement( + modality="image", + target=image_token, + replacement=get_replacement_idefics3, + ) + ] + + +class Idefics3SimpleMLP(nn.Module): + + def __init__( + self, + config: Idefics3Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + input_size = config.vision_config.hidden_size * (config.scale_factor** + 2) + output_size = config.text_config.hidden_size + self.proj = ReplicatedLinear( + input_size, + output_size, + bias=False, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "proj"), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out, _ = self.proj(x) + return out + + +class Idefics3Connector(nn.Module): + + def __init__( + self, + config: Idefics3Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.scale_factor = config.scale_factor + self.modality_projection = Idefics3SimpleMLP( + config, + quant_config, + prefix=maybe_prefix(prefix, "modality_projection"), + ) + + def pixel_shuffle(self, + x: torch.Tensor, + scale_factor: int = 2) -> torch.Tensor: + bsz, seq, embed_dim = x.size() + height = width = int(seq**0.5) + x = x.view(bsz, height, width, embed_dim) + x = x.view(bsz, height, int(width / scale_factor), + embed_dim * scale_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape( + bsz, + int(width / scale_factor), + int(height / scale_factor), + embed_dim * (scale_factor**2), + ) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(seq / (scale_factor**2)), + embed_dim * (scale_factor**2)) + return x + + def forward(self, image_hidden_states: torch.Tensor) -> torch.Tensor: + image_hidden_states = self.pixel_shuffle(image_hidden_states, + self.scale_factor) + image_hidden_states = self.modality_projection(image_hidden_states) + return image_hidden_states + + +class Idefics3Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config: Idefics3Config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.config = config + self.vocab_size = self.config.text_config.vocab_size + self.vision_model = Idefics3VisionTransformer( + config.vision_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "vision_model")) + self.connector = Idefics3Connector( + config, + quant_config, + prefix=maybe_prefix(prefix, "connector"), + ) + self.text_model = LlamaModel( + vllm_config=vllm_config.with_hf_config(config.text_config), + prefix=maybe_prefix(prefix, "text_model"), + ) + + self.image_seq_len = int( + ((config.vision_config.image_size // + config.vision_config.patch_size)**2) / (config.scale_factor**2)) + self.image_token_id = self.config.image_token_id + + def image_pixels_to_features( + self, + pixel_values: torch.Tensor, + pixel_attention_mask: torch.Tensor, + ) -> torch.Tensor: + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the vision tower + pixel_values = pixel_values.to( + dtype=self.vision_model.embeddings.patch_embedding.weight.dtype + ) # fp16 compatibility + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum( + dim=(-1, -2, -3)) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + # Remove padding images from the mask + pixel_attention_mask = pixel_attention_mask[ + real_images_inds].contiguous() + + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.unfold(dimension=1, + size=patch_size, + step=patch_size) + patches_subgrid = patches_subgrid.unfold(dimension=2, + size=patch_size, + step=patch_size) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + return image_hidden_states + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.text_model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + + hidden_states = self.text_model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor( + Idefics3MultiModalProcessor, + info=Idefics3ProcessingInfo, + dummy_inputs=Idefics3DummyInputsBuilder) +class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "" + + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.model = Idefics3Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.image_token_id = self.config.image_token_id + + self.lm_head = ParallelLMHead( + config.text_config.vocab_size, + config.text_config.hidden_size, + quant_config=quant_config, + ) + if self.config.text_config.tie_word_embeddings: + self.lm_head.weight = self.model.text_model.wte.weight + self.logits_processor = LogitsProcessor(config.text_config.vocab_size) + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + "The expected shape of pixel values per image per batch " + f" per patch is {expected_expr}. " + f"You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[ImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + + return Idefics3ImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds, concat=True), + ) + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + pixel_attention_mask = kwargs.pop("pixel_attention_mask") + if not isinstance(pixel_attention_mask, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel_attention_mask. " + f"Got type: {type(pixel_attention_mask)}") + + num_patches = kwargs.pop("num_patches") + if not isinstance(num_patches, (torch.Tensor, list)): + raise ValueError("Incorrect type of num_patches. " + f"Got type: {type(num_patches)}") + + pixel_values = flatten_bn(pixel_values, concat=True) + pixel_attention_mask = flatten_bn(pixel_attention_mask, + concat=True) + num_patches = flatten_bn(num_patches, concat=True) + + return Idefics3ImagePixelInputs( + type="pixel_values", + pixel_values=self._validate_pixel_values(pixel_values), + pixel_attention_mask=pixel_attention_mask, + num_patches=num_patches, + ) + + raise AssertionError("This line should be unreachable.") + + def _process_image_pixels( + self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor: + pixel_values = inputs["pixel_values"] + pixel_attention_mask = inputs["pixel_attention_mask"] + + return self.model.image_pixels_to_features( + pixel_values, + pixel_attention_mask=pixel_attention_mask, + ) + + def _process_image_input( + self, + image_input: ImageInputs, + ) -> Union[torch.Tensor, list[torch.Tensor]]: + if image_input["type"] == "image_embeds": + return image_input["data"] + + image_features = self._process_image_pixels(image_input) + image_features = self.model.connector(image_features) + + num_patches = image_input["num_patches"] + return [ + e.flatten(0, 1) for e in image_features.split(num_patches.tolist()) + ] + + def get_language_model(self) -> torch.nn.Module: + return self.model + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + + return self._process_image_input(image_input) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None \ + and len(multimodal_embeddings) != 0: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + self.config.image_token_id, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.model.text_model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="model.text_model", + connector="model.connector", + tower_model="model.vision_model") diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py new file mode 100644 index 0000000..c7b3bf1 --- /dev/null +++ b/vllm/model_executor/models/interfaces.py @@ -0,0 +1,685 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable, MutableSequence +from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, + Union, overload, runtime_checkable) + +import torch +from torch import Tensor +from typing_extensions import Self, TypeIs + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.utils import supports_kw + +from .interfaces_base import is_pooling_model + +if TYPE_CHECKING: + from vllm.attention import AttentionMetadata + from vllm.model_executor.models.utils import WeightsMapper + from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig + from vllm.sequence import IntermediateTensors + +logger = init_logger(__name__) + +MultiModalEmbeddings = Union[list[Tensor], Tensor, tuple[Tensor, ...]] +""" +The output embeddings must be one of the following formats: + +- A list or tuple of 2D tensors, where each tensor corresponds to + each input multimodal data item (e.g, image). +- A single 3D tensor, with the batch dimension grouping the 2D tensors. +""" + + +@runtime_checkable +class SupportsMultiModal(Protocol): + """The interface required for all multi-modal models.""" + + supports_multimodal: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports multi-modal inputs. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. + """ + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + """ + Get the placeholder text for the `i`th `modality` item in the prompt. + """ + ... + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + """ + Returns multimodal embeddings generated from multimodal kwargs + to be merged with text embeddings. + + Note: + The returned multimodal embeddings must be in the same order as + the appearances of their corresponding multimodal data item in the + input prompt. + """ + ... + + def get_language_model(self) -> torch.nn.Module: + """ + Returns the underlying language model used for text generation. + + This is typically the `torch.nn.Module` instance responsible for + processing the merged multimodal embeddings and producing hidden states + + Returns: + torch.nn.Module: The core language model component. + """ + ... + + # Only for models that support v0 chunked prefill + # TODO(ywang96): Remove this overload once v0 is deprecated + @overload + def get_input_embeddings( + self, + input_ids: Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + attn_metadata: Optional["AttentionMetadata"] = None, + ) -> Tensor: + ... + + @overload + def get_input_embeddings( + self, + input_ids: Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> Tensor: + """ + Returns the input embeddings merged from the text embeddings from + input_ids and the multimodal embeddings generated from multimodal + kwargs. + """ + ... + + +# We can't use runtime_checkable with ClassVar for issubclass checks +# so we need to treat the class as an instance and use isinstance instead +@runtime_checkable +class _SupportsMultiModalType(Protocol): + supports_multimodal: Literal[True] + + +@overload +def supports_multimodal( + model: type[object]) -> TypeIs[type[SupportsMultiModal]]: + ... + + +@overload +def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]: + ... + + +def supports_multimodal( + model: Union[type[object], object], +) -> Union[TypeIs[type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]: + if isinstance(model, type): + return isinstance(model, _SupportsMultiModalType) + + return isinstance(model, SupportsMultiModal) + + +@runtime_checkable +class SupportsLoRA(Protocol): + """The interface required for all models that support LoRA.""" + + supports_lora: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports LoRA. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. + """ + # The `embedding_module` and `embedding_padding_modules` + # are empty by default. + embedding_modules: ClassVar[dict[str, str]] = {} + embedding_padding_modules: ClassVar[list[str]] = [] + packed_modules_mapping: ClassVar[dict[str, list[str]]] = {} + + +# We can't use runtime_checkable with ClassVar for issubclass checks +# so we need to treat the class as an instance and use isinstance instead +@runtime_checkable +class _SupportsLoRAType(Protocol): + supports_lora: Literal[True] + + packed_modules_mapping: dict[str, list[str]] + embedding_modules: dict[str, str] + embedding_padding_modules: list[str] + + +@overload +def supports_lora(model: type[object]) -> TypeIs[type[SupportsLoRA]]: + ... + + +@overload +def supports_lora(model: object) -> TypeIs[SupportsLoRA]: + ... + + +def supports_lora( + model: Union[type[object], object], +) -> Union[TypeIs[type[SupportsLoRA]], TypeIs[SupportsLoRA]]: + result = _supports_lora(model) + + if not result: + lora_attrs = ( + "packed_modules_mapping", + "embedding_modules", + "embedding_padding_modules", + ) + missing_attrs = tuple(attr for attr in lora_attrs + if not hasattr(model, attr)) + + if getattr(model, "supports_lora", False): + if missing_attrs: + logger.warning( + "The model (%s) sets `supports_lora=True`, " + "but is missing LoRA-specific attributes: %s", + model, + missing_attrs, + ) + else: + if not missing_attrs: + logger.warning( + "The model (%s) contains all LoRA-specific attributes, " + "but does not set `supports_lora=True`.", model) + + return result + + +def _supports_lora(model: Union[type[object], object]) -> bool: + if isinstance(model, type): + return isinstance(model, _SupportsLoRAType) + + return isinstance(model, SupportsLoRA) + + +@runtime_checkable +class SupportsPP(Protocol): + """The interface required for all models that support pipeline parallel.""" + + supports_pp: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports pipeline parallel. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. + """ + + def make_empty_intermediate_tensors( + self, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> "IntermediateTensors": + """Called when PP rank > 0 for profiling purposes.""" + ... + + def forward( + self, + *, + intermediate_tensors: Optional["IntermediateTensors"], + ) -> Union[Tensor, "IntermediateTensors"]: + """ + Accept [`IntermediateTensors`][vllm.sequence.IntermediateTensors] when + PP rank > 0. + + Return [`IntermediateTensors`][vllm.sequence.IntermediateTensors] only + for the last PP rank. + """ + ... + + +# We can't use runtime_checkable with ClassVar for issubclass checks +# so we need to treat the class as an instance and use isinstance instead +@runtime_checkable +class _SupportsPPType(Protocol): + supports_pp: Literal[True] + + def make_empty_intermediate_tensors( + self, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> "IntermediateTensors": + ... + + def forward( + self, + *, + intermediate_tensors: Optional["IntermediateTensors"], + ) -> Union[Tensor, "IntermediateTensors"]: + ... + + +@overload +def supports_pp(model: type[object]) -> TypeIs[type[SupportsPP]]: + ... + + +@overload +def supports_pp(model: object) -> TypeIs[SupportsPP]: + ... + + +def supports_pp( + model: Union[type[object], object], +) -> Union[bool, TypeIs[type[SupportsPP]], TypeIs[SupportsPP]]: + supports_attributes = _supports_pp_attributes(model) + supports_inspect = _supports_pp_inspect(model) + + if supports_attributes and not supports_inspect: + logger.warning( + "The model (%s) sets `supports_pp=True`, but does not accept " + "`intermediate_tensors` in its `forward` method", model) + + if not supports_attributes: + pp_attrs = ("make_empty_intermediate_tensors", ) + missing_attrs = tuple(attr for attr in pp_attrs + if not hasattr(model, attr)) + + if getattr(model, "supports_pp", False): + if missing_attrs: + logger.warning( + "The model (%s) sets `supports_pp=True`, " + "but is missing PP-specific attributes: %s", + model, + missing_attrs, + ) + else: + if not missing_attrs: + logger.warning( + "The model (%s) contains all PP-specific attributes, " + "but does not set `supports_pp=True`.", model) + + return supports_attributes and supports_inspect + + +def _supports_pp_attributes(model: Union[type[object], object]) -> bool: + if isinstance(model, type): + return isinstance(model, _SupportsPPType) + + return isinstance(model, SupportsPP) + + +def _supports_pp_inspect(model: Union[type[object], object]) -> bool: + model_forward = getattr(model, "forward", None) + if not callable(model_forward): + return False + + return supports_kw(model_forward, "intermediate_tensors") + + +@runtime_checkable +class HasInnerState(Protocol): + """The interface required for all models that has inner state.""" + + has_inner_state: ClassVar[Literal[True]] = True + """ + A flag that indicates this model has inner state. + Models that has inner state usually need access to the scheduler_config + for max_num_seqs, etc. True for e.g. both Mamba and Jamba. + """ + + +@runtime_checkable +class _HasInnerStateType(Protocol): + has_inner_state: ClassVar[Literal[True]] + + +@overload +def has_inner_state(model: object) -> TypeIs[HasInnerState]: + ... + + +@overload +def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]: + ... + + +def has_inner_state( + model: Union[type[object], object] +) -> Union[TypeIs[type[HasInnerState]], TypeIs[HasInnerState]]: + if isinstance(model, type): + return isinstance(model, _HasInnerStateType) + + return isinstance(model, HasInnerState) + + +@runtime_checkable +class IsAttentionFree(Protocol): + """The interface required for all models like Mamba that lack attention, + but do have state whose size is constant wrt the number of tokens.""" + + is_attention_free: ClassVar[Literal[True]] = True + """ + A flag that indicates this model has no attention. + Used for block manager and attention backend selection. + True for Mamba but not Jamba. + """ + + +@runtime_checkable +class _IsAttentionFreeType(Protocol): + is_attention_free: ClassVar[Literal[True]] + + +@overload +def is_attention_free(model: object) -> TypeIs[IsAttentionFree]: + ... + + +@overload +def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]: + ... + + +def is_attention_free( + model: Union[type[object], object] +) -> Union[TypeIs[type[IsAttentionFree]], TypeIs[IsAttentionFree]]: + if isinstance(model, type): + return isinstance(model, _IsAttentionFreeType) + + return isinstance(model, IsAttentionFree) + + +@runtime_checkable +class IsHybrid(Protocol): + """The interface required for all models like Jamba that have both + attention and mamba blocks, indicates that + hf_config has 'layers_block_type'""" + + is_hybrid: ClassVar[Literal[True]] = True + """ + A flag that indicates this model has both mamba and attention blocks + , also indicates that the model's hf_config has + 'layers_block_type' """ + + +@runtime_checkable +class _IsHybridType(Protocol): + is_hybrid: ClassVar[Literal[True]] + + +@overload +def is_hybrid(model: object) -> TypeIs[IsHybrid]: + ... + + +@overload +def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]: + ... + + +def is_hybrid( + model: Union[type[object], object] +) -> Union[TypeIs[type[IsHybrid]], TypeIs[IsHybrid]]: + if isinstance(model, type): + return isinstance(model, _IsHybridType) + + return isinstance(model, IsHybrid) + + +@runtime_checkable +class MixtureOfExperts(Protocol): + """ + Check if the model is a mixture of experts (MoE) model. + """ + + expert_weights: MutableSequence[Iterable[Tensor]] + """ + Expert weights saved in this rank. + + The first dimension is the layer, and the second dimension is different + parameters in the layer, e.g. up/down projection weights. + """ + + num_moe_layers: int + """Number of MoE layers in this model.""" + + num_expert_groups: int + """Number of expert groups in this model.""" + + num_logical_experts: int + """Number of logical experts in this model.""" + + num_physical_experts: int + """Number of physical experts in this model.""" + + num_local_physical_experts: int + """Number of local physical experts in this model.""" + + num_routed_experts: int + """Number of routed experts in this model.""" + + num_shared_experts: int + """Number of shared experts in this model.""" + + num_redundant_experts: int + """Number of redundant experts in this model.""" + + def set_eplb_state( + self, + expert_load_view: Tensor, + logical_to_physical_map: Tensor, + logical_replica_count: Tensor, + ) -> None: + """ + Register the EPLB state in the MoE model. + + Since these are views of the actual EPLB state, any changes made by + the EPLB algorithm are automatically reflected in the model's behavior + without requiring additional method calls to set new states. + + You should also collect model's `expert_weights` here instead of in + the weight loader, since after initial weight loading, further + processing like quantization may be applied to the weights. + + Args: + expert_load_view: A view of the expert load metrics tensor. + logical_to_physical_map: Mapping from logical to physical experts. + logical_replica_count: Count of replicas for each logical expert. + """ + ... + + +def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]: + return isinstance(model, MixtureOfExperts) + + +@runtime_checkable +class HasNoOps(Protocol): + has_noops: ClassVar[Literal[True]] = True + + +@runtime_checkable +class _HasNoOpsType(Protocol): + has_noops: ClassVar[Literal[True]] + + +@overload +def has_noops(model: object) -> TypeIs[HasNoOps]: + ... + + +@overload +def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]: + ... + + +def has_noops( + model: Union[type[object], object] +) -> Union[TypeIs[type[HasNoOps]], TypeIs[HasNoOps]]: + if isinstance(model, type): + return isinstance(model, _HasNoOpsType) + + return isinstance(model, HasNoOps) + + +@runtime_checkable +class SupportsCrossEncoding(Protocol): + """The interface required for all models that support cross encoding.""" + + supports_cross_encoding: ClassVar[Literal[True]] = True + + +@overload +def supports_cross_encoding( + model: type[object]) -> TypeIs[type[SupportsCrossEncoding]]: + ... + + +@overload +def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]: + ... + + +def _supports_cross_encoding( + model: Union[type[object], object], +) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: + + if isinstance(model, type): + return isinstance(model, SupportsCrossEncoding) + + return isinstance(model, SupportsCrossEncoding) + + +def supports_cross_encoding( + model: Union[type[object], object], +) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: + return is_pooling_model(model) and _supports_cross_encoding(model) + + +def has_step_pooler(model: Union[type[object], object]) -> bool: + """Check if the model uses step pooler.""" + return is_pooling_model(model) and any( + type(module).__name__ == "StepPool" for module in model.modules()) + + +class SupportsQuant: + """The interface required for all models that support quantization.""" + + hf_to_vllm_mapper: ClassVar[Optional["WeightsMapper"]] = None + packed_modules_mapping: ClassVar[Optional[dict[str, list[str]]]] = None + quant_config: Optional[QuantizationConfig] = None + + def __new__(cls, *args, **kwargs) -> Self: + instance = super().__new__(cls) + + # find config passed in arguments + quant_config = cls._find_quant_config(*args, **kwargs) + if quant_config is not None: + + # attach config to model for general use + instance.quant_config = quant_config + + # apply model mappings to config for proper config-model matching + # NOTE: `TransformersForCausalLM` is not supported due to how this + # class defines `hf_to_vllm_mapper` as a post-init `@property`. + # After this is fixed, get `instance.hf_to_vllm_mapper` directly + if getattr(instance, "hf_to_vllm_mapper", None) is not None: + instance.quant_config.apply_vllm_mapper( + instance.hf_to_vllm_mapper) + if getattr(instance, "packed_modules_mapping", None) is not None: + instance.quant_config.packed_modules_mapping.update( + instance.packed_modules_mapping) + + return instance + + @staticmethod + def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]: + """Find quant config passed through model constructor args""" + from vllm.config import VllmConfig # avoid circular import + + args_values = list(args) + list(kwargs.values()) + for arg in args_values: + if isinstance(arg, VllmConfig): + return arg.quant_config + + if isinstance(arg, QuantizationConfig): + return arg + + return None + + +@runtime_checkable +class SupportsTranscription(Protocol): + """The interface required for all models that support transcription.""" + + supports_transcription: ClassVar[Literal[True]] = True + + @classmethod + def get_decoder_prompt(cls, language: str, task_type: str, + prompt: str) -> str: + """Get the decoder prompt for the ASR model.""" + ... + + @classmethod + def validate_language(cls, language: str) -> bool: + """Check if the model supports a specific ISO639_1 language.""" + ... + + +@overload +def supports_transcription( + model: type[object]) -> TypeIs[type[SupportsTranscription]]: + ... + + +@overload +def supports_transcription(model: object) -> TypeIs[SupportsTranscription]: + ... + + +def supports_transcription( + model: Union[type[object], object], +) -> Union[TypeIs[type[SupportsTranscription]], TypeIs[SupportsTranscription]]: + if isinstance(model, type): + return isinstance(model, SupportsTranscription) + + return isinstance(model, SupportsTranscription) + + +@runtime_checkable +class SupportsV0Only(Protocol): + """Models with this interface are not compatible with V1 vLLM.""" + + supports_v0_only: ClassVar[Literal[True]] = True + + +@overload +def supports_v0_only(model: type[object]) -> TypeIs[type[SupportsV0Only]]: + ... + + +@overload +def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]: + ... + + +def supports_v0_only( + model: Union[type[object], object], +) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]: + if isinstance(model, type): + return isinstance(model, SupportsV0Only) + + return isinstance(model, SupportsV0Only) diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py new file mode 100644 index 0000000..4a1ea74 --- /dev/null +++ b/vllm/model_executor/models/interfaces_base.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import (TYPE_CHECKING, Optional, Protocol, Union, overload, + runtime_checkable) + +import torch +import torch.nn as nn +from typing_extensions import TypeIs, TypeVar + +from vllm.logger import init_logger +from vllm.utils import supports_kw + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.model_executor.layers.pooler import PoolerOutput + from vllm.model_executor.pooling_metadata import PoolingMetadata + from vllm.model_executor.sampling_metadata import SamplingMetadata + +logger = init_logger(__name__) + +# The type of hidden states +# Currently, T = torch.Tensor for all models except for Medusa +# which has T = list[torch.Tensor] +T = TypeVar("T", default=torch.Tensor) +T_co = TypeVar("T_co", default=torch.Tensor, covariant=True) + +# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags +# for the base interfaces to avoid breaking OOT registration for existing models +# that don't inherit from the base interface classes + + +@runtime_checkable +class VllmModel(Protocol[T_co]): + """The interface required for all models in vLLM.""" + + def __init__( + self, + vllm_config: "VllmConfig", + prefix: str = "", + ) -> None: + ... + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + ) -> T_co: + ... + + +def _check_vllm_model_init(model: Union[type[object], object]) -> bool: + model_init = model.__init__ + return supports_kw(model_init, "vllm_config") + + +def _check_vllm_model_forward(model: Union[type[object], object]) -> bool: + model_forward = getattr(model, "forward", None) + if not callable(model_forward): + return False + + vllm_kws = ("input_ids", "positions") + missing_kws = tuple(kw for kw in vllm_kws + if not supports_kw(model_forward, kw)) + + if missing_kws and (isinstance(model, type) + and issubclass(model, nn.Module)): + logger.warning( + "The model (%s) is missing " + "vLLM-specific keywords from its `forward` method: %s", + model, + missing_kws, + ) + + return len(missing_kws) == 0 + + +@overload +def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]: + ... + + +@overload +def is_vllm_model(model: object) -> TypeIs[VllmModel]: + ... + + +def is_vllm_model( + model: Union[type[object], object], +) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]: + return _check_vllm_model_init(model) and _check_vllm_model_forward(model) + + +@runtime_checkable +class VllmModelForTextGeneration(VllmModel[T], Protocol[T]): + """The interface required for all generative models in vLLM.""" + + def compute_logits( + self, + hidden_states: T, + sampling_metadata: "SamplingMetadata", + ) -> Optional[T]: + """Return `None` if TP rank > 0.""" + ... + + +@overload +def is_text_generation_model( + model: type[object]) -> TypeIs[type[VllmModelForTextGeneration]]: + ... + + +@overload +def is_text_generation_model( + model: object) -> TypeIs[VllmModelForTextGeneration]: + ... + + +def is_text_generation_model( + model: Union[type[object], object], +) -> Union[TypeIs[type[VllmModelForTextGeneration]], + TypeIs[VllmModelForTextGeneration]]: + if not is_vllm_model(model): + return False + + if isinstance(model, type): + return isinstance(model, VllmModelForTextGeneration) + + return isinstance(model, VllmModelForTextGeneration) + + +@runtime_checkable +class VllmModelForPooling(VllmModel[T], Protocol[T]): + """The interface required for all pooling models in vLLM.""" + + def pooler( + self, + hidden_states: T, + pooling_metadata: "PoolingMetadata", + ) -> "PoolerOutput": + """Only called on TP rank 0.""" + ... + + +@overload +def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]: + ... + + +@overload +def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]: + ... + + +def is_pooling_model( + model: Union[type[object], object], +) -> Union[TypeIs[type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]: + if not is_vllm_model(model): + return False + + if isinstance(model, type): + return isinstance(model, VllmModelForPooling) + + return isinstance(model, VllmModelForPooling) diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py new file mode 100644 index 0000000..357f234 --- /dev/null +++ b/vllm/model_executor/models/intern_vit.py @@ -0,0 +1,482 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2023 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- +from collections.abc import Iterable +from functools import partial +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.attention.layer import MultiHeadAttention +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather) +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +import vllm.envs as envs + + +NORM2FN = { + 'rms_norm': RMSNorm, + 'layer_norm': nn.LayerNorm, +} + + +class InternVisionEmbeddings(nn.Module): + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + + self.patch_embedding = nn.Conv2d(in_channels=3, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter( + torch.randn(1, self.num_positions, self.embed_dim)) + + def _get_pos_embed(self, pos_embed: torch.Tensor, H: int, W: int): + target_dtype = pos_embed.dtype + pos_embed = pos_embed.float().reshape( + 1, self.image_size // self.patch_size, + self.image_size // self.patch_size, -1).permute(0, 3, 1, 2) + pos_embed = F.interpolate(pos_embed, + size=(H, W), + mode='bicubic', + align_corners=False) + return pos_embed.reshape(1, -1, H * W).permute(0, 2, + 1).to(target_dtype) + + def _get_position_embedding(self, H: int, W: int) -> torch.Tensor: + position_embedding = self.position_embedding + if self.num_patches == H * W: + return position_embedding + + return torch.cat( + [ + position_embedding[:, :1, :], + self._get_pos_embed(position_embedding[:, 1:, :], H, W), + ], + dim=1, + ) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to( + target_dtype)) # shape = [*, channel, width, height] + batch_size, _, height, width = patch_embeds.shape + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, + -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + position_embedding = self._get_position_embedding(height, width) + embeddings = embeddings + position_embedding.to(target_dtype) + return embeddings + + +class InternVisionPatchModel(nn.Module): + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + self.embeddings = InternVisionEmbeddings(config) + + def get_input_embeddings(self): + return self.embeddings + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + pixel_embeds: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + if pixel_values is None and pixel_embeds is None: + raise ValueError( + 'You have to specify pixel_values or pixel_embeds') + + if pixel_embeds is not None: + hidden_states = pixel_embeds + elif pixel_values is not None: + if pixel_values.ndim == 4: + hidden_states = self.embeddings(pixel_values) + else: + raise ValueError( + f'wrong pixel_values size: {pixel_values.shape}') + + return hidden_states + + +class InternParallelAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_dummy_heads: int = 0, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f'embed_dim must be divisible by num_heads ' + f'(got `embed_dim`: {self.embed_dim} and `num_heads`:' + f' {self.num_heads}).') + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + # Additional dummy heads are used to enable TP for common GPU counts. + self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim + self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads, + self.tp_size) + + self.scale = self.head_dim**-0.5 + self.qkv = QKVParallelLinear( + self.embed_dim, + self.head_dim, + num_dummy_heads + self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + ) + + self.qk_normalization = config.qk_normalization + + if self.qk_normalization: + self.q_norm = RMSNorm(self.dummy_dim, + eps=config.layer_norm_eps, + var_hidden_size=self.embed_dim) + self.k_norm = RMSNorm(self.dummy_dim, + eps=config.layer_norm_eps, + var_hidden_size=self.embed_dim) + + self.proj = RowParallelLinear( + self.dummy_dim, + self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + ) + + self.attn = MultiHeadAttention(self.num_heads_per_partition, + self.head_dim, self.scale) + + def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor): + if self.tp_size > 1: + q = tensor_model_parallel_all_gather(q.contiguous()) + k = tensor_model_parallel_all_gather(k.contiguous()) + q = self.q_norm(q) + k = self.k_norm(k) + if self.tp_size > 1: + splitter = partial(split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + return q, k + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, _ = x.shape + qkv, _ = self.qkv(x) + q, k, v = qkv.chunk(3, dim=-1) + + if self.qk_normalization: + q, k = self._apply_qk_norm(q, k) + + out = self.attn(q, k, v) + out, _ = self.proj(out) + return out + + +class InternSdpaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: PretrainedConfig, + *, + num_dummy_heads: int = 0, + ) -> None: + super().__init__() + + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f'embed_dim must be divisible by num_heads ' + f'(got `embed_dim`: {self.embed_dim} and `num_heads`:' + f' {self.num_heads}).') + + # Additional dummy heads are used to enable TP for common GPU counts. + self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim + + self.scale = self.head_dim**-0.5 + self.qkv = nn.Linear(self.embed_dim, + 3 * self.dummy_dim, + bias=config.qkv_bias) + + self.qk_normalization = config.qk_normalization + + if self.qk_normalization: + self.q_norm = RMSNorm(self.dummy_dim, + eps=config.layer_norm_eps, + var_hidden_size=self.embed_dim) + self.k_norm = RMSNorm(self.dummy_dim, + eps=config.layer_norm_eps, + var_hidden_size=self.embed_dim) + + self.proj = nn.Linear(self.dummy_dim, self.embed_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x) + q, k, v = qkv.chunk(3, dim=-1) + + q = q.view(B, N, self.num_heads, self.head_dim) + k = k.view(B, N, self.num_heads, self.head_dim) + v = v.view(B, N, self.num_heads, self.head_dim) + + if self.qk_normalization: + B_, N_, H_, D_ = q.shape + q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_) + k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) + x = x.transpose(1, 2).reshape(B, N, -1) + + x = self.proj(x) + return x + + +class InternMLP(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1") + self.fc2 = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + + return hidden_states + + +class InternVisionEncoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_dummy_heads: int = 0, + prefix: str = "", + ) -> None: + super().__init__() + + self.embed_dim = config.hidden_size + self.intermediate_size = config.intermediate_size + self.norm_type = config.norm_type + + self.attn = self._init_attn(config, + quant_config, + num_dummy_heads=num_dummy_heads, + prefix=f"{prefix}.attn") + + self.mlp = InternMLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.norm1 = NORM2FN[self.norm_type](self.embed_dim, + eps=config.layer_norm_eps) + self.norm2 = NORM2FN[self.norm_type](self.embed_dim, + eps=config.layer_norm_eps) + + self.ls1 = nn.Parameter(config.initializer_factor * + torch.ones(self.embed_dim)) + self.ls2 = nn.Parameter(config.initializer_factor * + torch.ones(self.embed_dim)) + + def _init_attn( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + *, + num_dummy_heads: int, + prefix: str = "", + ): + # fallback to sdpa attention if tp unavailable + tp_size = get_tensor_model_parallel_world_size() + num_heads = config.num_attention_heads + + if (num_heads + num_dummy_heads) % tp_size == 0: + return InternParallelAttention(config, + quant_config=quant_config, + num_dummy_heads=num_dummy_heads, + prefix=prefix) + + return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads) + + def forward( + self, + hidden_states: torch.Tensor, + ): + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states)) * self.ls1 + + hidden_states = hidden_states + self.mlp( + self.norm2(hidden_states)) * self.ls2 + + return hidden_states + + +class InternVisionEncoder(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + num_dummy_heads: int = 0, + prefix: str = "", + ): + super().__init__() + + self.config = config + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + + self.layers = nn.ModuleList([ + InternVisionEncoderLayer(config, + quant_config, + num_dummy_heads=num_dummy_heads, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(num_hidden_layers) + ]) + + def forward(self, inputs_embeds: torch.Tensor): + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states) + + return hidden_states + + +class InternVisionModel(nn.Module): + + packed_modules_mapping = { + "qkv": ["qkv"], + } + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + num_dummy_heads: int = 0, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + + self.embeddings = InternVisionEmbeddings(config) + self.encoder = InternVisionEncoder( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + num_dummy_heads=num_dummy_heads, + prefix=f"{prefix}.encoder", + ) + + def get_input_embeddings(self): + return self.embeddings + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + pixel_embeds: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + if pixel_values is None and pixel_embeds is None: + raise ValueError( + 'You have to specify pixel_values or pixel_embeds') + + if pixel_embeds is not None: + hidden_states = pixel_embeds + elif pixel_values is not None: + if pixel_values.ndim == 4: + hidden_states = self.embeddings(pixel_values) + else: + raise ValueError( + f'wrong pixel_values size: {pixel_values.shape}') + + encoder_outputs = self.encoder(inputs_embeds=hidden_states) + + return encoder_outputs + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py new file mode 100644 index 0000000..e8549b4 --- /dev/null +++ b/vllm/model_executor/models/internlm2.py @@ -0,0 +1,455 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from functools import partial +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class InternLM2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.w2 = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.w2", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.w2(x) + return x + + +class InternLM2Attention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.total_num_heads = num_heads + assert self.total_num_heads % self.tp_size == 0 + self.num_heads = self.total_num_heads // self.tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= self.tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % self.tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert self.tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.key_value_groups = int(self.num_heads / self.num_kv_heads) + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.wqkv = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wqkv", + ) + self.wo = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wo", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + def split_qkv(self, qkv: torch.Tensor): + seq_len = qkv.shape[0] + if self.tp_size > 1: + qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size + qkv = tensor_model_parallel_all_gather(qkv) + qkv = torch.split(qkv, qkv_map, dim=-1) + qkv = qkv[::3] + qkv[1::3] + qkv[2::3] + qkv = torch.cat(qkv, dim=-1) + + qkv = qkv.view(seq_len, self.total_num_kv_heads, + self.key_value_groups + 2, self.head_dim) + q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2) + q = q.reshape(seq_len, self.q_size * self.tp_size) + k = k.reshape(seq_len, self.kv_size * self.tp_size) + v = v.reshape(seq_len, self.kv_size * self.tp_size) + + if self.tp_size > 1: + splitter = partial(split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] + return q, k, v + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.wqkv(hidden_states) + q, k, v = self.split_qkv(qkv) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.wo(attn_output) + return output + + +class InternLMDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.attention = InternLM2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attention", + ) + self.feed_forward = InternLM2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + self.attention_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.attention_norm(hidden_states) + else: + hidden_states, residual = self.attention_norm( + hidden_states, residual) + hidden_states = self.attention( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.ffn_norm(hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class InternLM2Model(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[InternLMDecoderLayer] = InternLMDecoderLayer): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.vocab_size = config.vocab_size + self.tok_embeddings = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: layer_type( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers") + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.tok_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer(positions, hidden_states, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): + packed_modules_mapping = { + "wqkv": ["wqkv"], + "gate_up_proj": ["w1", "w3"], + } + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + model_type: type[InternLM2Model] = InternLM2Model): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.quant_config = quant_config + self.lora_config = lora_config + + self.model = model_type(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.output = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "output")) + if self.config.tie_word_embeddings: + self.output.weight = self.model.tok_embeddings.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.output, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "w1", 0), + ("gate_up_proj", "w3", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class InternLM2ForRewardModel(InternLM2ForCausalLM): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + model_type: type[InternLM2Model] = InternLM2Model, + ): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + model_type=model_type) + + for attr in ("output", "logits_processor"): + delattr(self, attr) + + config = vllm_config.model_config.hf_config + self.v_head = RowParallelLinear( + config.hidden_size, + 1, + bias=False, + input_is_parallel=False, + prefix=maybe_prefix(prefix, "v_head"), + ) + + pooler_config = vllm_config.model_config.pooler_config + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.ALL, + normalize=False, + softmax=False, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + logits, _ = self.v_head(hidden_states) + return logits + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) diff --git a/vllm/model_executor/models/internlm2_ve.py b/vllm/model_executor/models/internlm2_ve.py new file mode 100644 index 0000000..4bbb49d --- /dev/null +++ b/vllm/model_executor/models/internlm2_ve.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.internlm2 import (InternLM2Attention, + InternLM2ForCausalLM, + InternLM2MLP, InternLM2Model) +from vllm.sequence import IntermediateTensors + + +class InternLM2VEDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.attention = InternLM2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attention", + ) + self.feed_forward = InternLM2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + ) + self.feed_forward_ve = InternLM2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward_ve", + ) + self.attention_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + visual_token_mask: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.attention_norm(hidden_states) + else: + hidden_states, residual = self.attention_norm( + hidden_states, residual) + hidden_states = self.attention( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.ffn_norm(hidden_states, residual) + if visual_token_mask is not None and visual_token_mask.any(): + visual_token_mask = visual_token_mask.repeat( + 1, self.hidden_size).bool() + text_token_mask = ~visual_token_mask + hidden_states[visual_token_mask] = self.feed_forward_ve( + hidden_states[visual_token_mask].reshape( + -1, self.hidden_size)).flatten() + if text_token_mask.any(): + hidden_states[text_token_mask] = self.feed_forward( + hidden_states[text_token_mask].reshape( + -1, self.hidden_size)).flatten() + else: + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +class InternLM2VEModel(InternLM2Model): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + layer_type=InternLM2VEDecoderLayer) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + visual_token_mask: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.tok_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + visual_token_mask=visual_token_mask, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class InternLM2VEForCausalLM(InternLM2ForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + model_type=InternLM2VEModel) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py new file mode 100644 index 0000000..f8b9ea2 --- /dev/null +++ b/vllm/model_executor/models/internvl.py @@ -0,0 +1,1432 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2023 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- +from abc import ABC, abstractmethod +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Literal, Optional, TypedDict, TypeVar, Union + +import numpy.typing as npt +import torch +import torch.nn as nn +import torchvision.transforms as T +from PIL import Image +from transformers import BatchEncoding, PretrainedConfig, TensorType + +from vllm.config import VllmConfig +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.awq import AWQConfig +from vllm.model_executor.models.intern_vit import (InternVisionModel, + InternVisionPatchModel) +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.image import convert_image_mode +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) +from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, + ImageSize, MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.tokenizer import AnyTokenizer + +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) + +IMG_START = '' +IMG_END = '' +IMG_CONTEXT = '' + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + + +class InternVLImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values_flat: torch.Tensor + """ + Shape: + `(batch_size * num_images * (1 + num_patches), num_channels, height, width)` + """ + + num_patches: torch.Tensor + """Shape: `(batch_size * num_images)`""" + + +class InternVLImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: Union[torch.Tensor, list[torch.Tensor]] + """ + A tensor of shape `(num_images, total_image_feature_size, hidden_size)` + or a list of tensors of shape `(total_image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + """ + + +InternVLImageInputs = Union[InternVLImagePixelInputs, + InternVLImageEmbeddingInputs] + + +class InternVLVideoPixelInputs(TypedDict): + type: Literal["pixel_values_videos"] + pixel_values_flat: torch.Tensor + """ + Shape: + `(batch_size * num_video * num_frames, num_channels, height, width)` + """ + + num_patches: torch.Tensor + """Shape: `(batch_size * num_images)`""" + + +class InternVLVideoEmbeddingInputs(TypedDict): + type: Literal["video_embeds"] + data: Union[torch.Tensor, list[torch.Tensor]] + """ + A tensor of shape `(num_videos, total_video_feature_size, hidden_size)` + or a list of tensors of shape `(total_video_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + """ + + +InternVLVideoInputs = Union[InternVLVideoPixelInputs, + InternVLVideoEmbeddingInputs] + + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B +def build_transform(input_size: int): + MEAN, STD = IMAGENET_MEAN, IMAGENET_STD + return T.Compose([ + T.Lambda(lambda img: convert_image_mode(img, 'RGB')), + T.Resize((input_size, input_size), + interpolation=T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD) + ]) + + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B +def find_closest_aspect_ratio( + aspect_ratio: float, + target_ratios: list[tuple[int, int]], + *, + width: int, + height: int, + image_size: int, +) -> tuple[int, int]: + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + +def resolve_internvl_min_max_num( + *, + min_dynamic_patch: int, + max_dynamic_patch: int, + dynamic_image_size: bool, + use_thumbnail: bool, +) -> tuple[int, int]: + min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1 + max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 + + if use_thumbnail and max_dynamic_patch != 1: + max_dynamic_patch += 1 + + return min_dynamic_patch, max_dynamic_patch + + +def get_internvl_target_ratios( + min_num: int, + max_num: int, +) -> list[tuple[int, int]]: + target_ratios = {(i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) if min_num <= i * j <= max_num} + return sorted(target_ratios, key=lambda x: x[0] * x[1]) + + +def calculate_internvl_targets( + *, + orig_width: int, + orig_height: int, + target_ratios: list[tuple[int, int]], + image_size: int, + use_thumbnail: bool, +) -> tuple[int, int, int]: + aspect_ratio = orig_width / orig_height + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, + target_ratios, + width=orig_width, + height=orig_height, + image_size=image_size, + ) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # add thumbnail image if num_blocks != 1 + if use_thumbnail and blocks != 1: + blocks += 1 + + return blocks, target_width, target_height + + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B +def dynamic_preprocess_internvl( + image: Image.Image, + *, + target_ratios: list[tuple[int, int]], + image_size: int, + use_thumbnail: bool, +) -> list[Image.Image]: + orig_width, orig_height = image.size + + # calculate the number of blocks without thumbnail + blocks, target_width, target_height = calculate_internvl_targets( + orig_width=orig_width, + orig_height=orig_height, + target_ratios=target_ratios, + image_size=image_size, + use_thumbnail=False, + ) + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ((i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + + assert len(processed_images) == blocks + + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + + return processed_images + + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B +def image_to_pixel_values_internvl( + image: Image.Image, + *, + input_size: int, + min_num: int, + max_num: int, + use_thumbnail: bool, +) -> torch.Tensor: + target_ratios = get_internvl_target_ratios(min_num, max_num) + + transform = build_transform(input_size=input_size) + images = dynamic_preprocess_internvl( + image, + target_ratios=target_ratios, + image_size=input_size, + use_thumbnail=use_thumbnail, + ) + + pixel_values = torch.stack([transform(image) for image in images]) + return pixel_values + + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B +def video_to_pixel_values_internvl( + video: npt.NDArray, + *, + input_size: int, + min_num: int, + max_num: int, + use_thumbnail: bool, +) -> torch.Tensor: + target_ratios = get_internvl_target_ratios(min_num, max_num) + + transform = build_transform(input_size=input_size) + frames_list = list[Image.Image]() + for frame in video: + pil_frame = dynamic_preprocess_internvl( + Image.fromarray(frame, mode="RGB"), + target_ratios=target_ratios, + image_size=input_size, + use_thumbnail=use_thumbnail, + ) + assert len(pil_frame) == 1 + frames_list.extend(pil_frame) + + pixel_values = torch.stack([transform(image) for image in frames_list]) + return pixel_values + + +class BaseInternVLProcessor(ABC): + """ + This model doesn't define its own HF processor, + so we implement our own one here. + + The code to insert image tokens is based on: + https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252 + """ + + def __init__( + self, + config: PretrainedConfig, + tokenizer: AnyTokenizer, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + ) -> None: + super().__init__() + + self.config = config + self.tokenizer = tokenizer + + image_size: int = config.vision_config.image_size + patch_size: int = config.vision_config.patch_size + + if min_dynamic_patch is None: + min_dynamic_patch = config.min_dynamic_patch + assert isinstance(min_dynamic_patch, int) + + if max_dynamic_patch is None: + max_dynamic_patch = config.max_dynamic_patch + assert isinstance(max_dynamic_patch, int) + + if dynamic_image_size is None: + dynamic_image_size = config.dynamic_image_size + assert isinstance(dynamic_image_size, bool) + + self.num_image_token = int( + (image_size // patch_size)**2 * (config.downsample_ratio**2)) + self.image_size = image_size + self.min_dynamic_patch = min_dynamic_patch + self.max_dynamic_patch = max_dynamic_patch + self.dynamic_image_size = dynamic_image_size + self.use_thumbnail: bool = config.use_thumbnail + + @property + @abstractmethod + def image_token_id(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_image_repl( + self, + feature_size: int, + num_patches: Optional[int], + ) -> PromptUpdateDetails[str]: + raise NotImplementedError + + def resolve_min_max_num( + self, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + use_thumbnail: Optional[bool] = None, + ) -> tuple[int, int]: + min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch + is None else min_dynamic_patch) + max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch + is None else max_dynamic_patch) + dynamic_image_size = (self.dynamic_image_size if dynamic_image_size + is None else dynamic_image_size) + use_thumbnail = (self.use_thumbnail + if use_thumbnail is None else use_thumbnail) + + return resolve_internvl_min_max_num( + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + use_thumbnail=use_thumbnail, + ) + + def resolve_target_ratios( + self, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + use_thumbnail: Optional[bool] = None, + ) -> list[tuple[int, int]]: + min_num, max_num = self.resolve_min_max_num( + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + use_thumbnail=use_thumbnail, + ) + + return get_internvl_target_ratios(min_num, max_num) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + target_ratios = self.resolve_target_ratios( + use_thumbnail=False, # Applied in calculate_targets + ) + + num_patches, _, _ = calculate_internvl_targets( + orig_width=image_width, + orig_height=image_height, + image_size=self.image_size, + target_ratios=target_ratios, + use_thumbnail=self.use_thumbnail, + ) + + return num_patches * self.num_image_token + + def _images_to_pixel_values_lst( + self, + images: list[Image.Image], + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + ) -> list[torch.Tensor]: + min_num, max_num = self.resolve_min_max_num( + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + use_thumbnail=False, # Applied in image_to_pixel_values + ) + + return [ + image_to_pixel_values_internvl( + image, + input_size=self.image_size, + min_num=min_num, + max_num=max_num, + use_thumbnail=self.use_thumbnail, + ) for image in images + ] + + def _preprocess_image( + self, + text: list[str], + images: list[Image.Image], + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + ) -> tuple[list[str], dict[str, torch.Tensor]]: + if len(images) == 0: + image_inputs = {} + else: + pixel_values_lst = self._images_to_pixel_values_lst( + images, + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) + image_inputs: dict[str, NestedTensors] = { + "pixel_values_flat": + torch.cat(pixel_values_lst), + "image_num_patches": + torch.tensor([len(item) for item in pixel_values_lst]), + } + + for pixel_values in pixel_values_lst: + num_patches = pixel_values.shape[0] + feature_size = num_patches * self.num_image_token + + image_repl = self.get_image_repl(feature_size, num_patches) + text = [t.replace('', image_repl.full, 1) for t in text] + return text, image_inputs + + def _make_batch_input(self, + input_item: Optional[Union[Any, list[Any]]] = None): + if input_item is None: + input_item = [] + if not isinstance(input_item, list): + input_item = [input_item] + return input_item + + def __call__( + self, + text: Optional[Union[str, list[str]]] = None, + images: Optional[Union[Image.Image, list[Image.Image]]] = None, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + ) -> Mapping[str, NestedTensors]: + text, images = [self._make_batch_input(x) for x in (text, images)] + + text, image_inputs = self._preprocess_image( + text=text, + images=images, + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) + + text_inputs = self.tokenizer(text) + + return { + **BatchEncoding(text_inputs, tensor_type=return_tensors), + **image_inputs, + } + + +class InternVLProcessor(BaseInternVLProcessor): + """ + HF Processor for InternVLChatModel with extended video processing logic. + + Code for video processing is adapted from video example: + https://huggingface.co/OpenGVLab/InternVL3-1B#inference-with-transformers + """ + + def __init__( + self, + config: PretrainedConfig, + tokenizer: AnyTokenizer, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + video_token: Optional[str] = None, + ) -> None: + super().__init__( + config=config, + tokenizer=tokenizer, + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) + # add extra video token for video processing + self.video_token = video_token + + @property + def image_token_id(self) -> int: + return self.tokenizer.get_vocab()[IMG_CONTEXT] + + @property + def video_token_id(self) -> Optional[int]: + if self.video_token is None: + return None + return self.tokenizer.get_vocab().get(self.video_token, None) + + @property + def supports_video(self) -> bool: + return self.video_token_id is not None + + def _videos_to_pixel_values_lst( + self, + videos: list[npt.NDArray], + dynamic_image_size: Optional[bool] = None, + ) -> list[torch.Tensor]: + min_num, max_num = self.resolve_min_max_num( + min_dynamic_patch=1, + max_dynamic_patch=1, + dynamic_image_size=dynamic_image_size, + use_thumbnail=False, # Applied in image_to_pixel_values + ) + + return [ + video_to_pixel_values_internvl( + video, + input_size=self.image_size, + min_num=min_num, + max_num=max_num, + use_thumbnail=False, + ) for video in videos + ] + + def _preprocess_video( + self, + text: list[str], + videos: list[npt.NDArray], + dynamic_image_size: Optional[bool] = None, + ): + if len(videos) == 0 or not self.supports_video: + video_inputs = {} + else: + pixel_values_lst_video = self._videos_to_pixel_values_lst( + videos, + dynamic_image_size=dynamic_image_size, + ) + video_inputs: dict[str, NestedTensors] = { + "pixel_values_flat_video": + torch.cat(pixel_values_lst_video), + "video_num_patches": + torch.tensor([len(item) for item in pixel_values_lst_video]), + } + + for pixel_values in pixel_values_lst_video: + num_patches = pixel_values.shape[0] + + video_repl = self.get_video_repl(self.num_image_token, + num_patches, self.video_token) + text = [t.replace('